From 707f070a3cf570019fa8c2fb05d328c4a34ca0e9 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sat, 1 Dec 2018 12:52:41 +0100 Subject: [PATCH 01/20] build: fix dirty detection at the end of release build was using Bashisms --- Makefile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index c6fc773..a0f31a6 100644 --- a/Makefile +++ b/Makefile @@ -133,8 +133,9 @@ release: $(RELEASE_BINS) $(RELEASE_NOARCH) cd "$(ARTIFACTDIR)/release" && sha512sum $$(ls | sort) > sha512sum.txt @# note that we use ZREPL_VERSION and not _ZREPL_VERSION because we want to detect the override @if git describe --dirty 2>/dev/null | grep dirty >/dev/null; then \ - if [ "$(ZREPL_VERSION)" == "" ]; then \ - echo "[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override"; \ + echo '[INFO] either git reports checkout is dirty or git is not installed or this is not a git checkout'; \ + if [ "$(ZREPL_VERSION)" = "" ]; then \ + echo '[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override'; \ exit 1; \ fi; \ fi; From 3535b251ab0a7f7fb5535ed6b51d3ddb88f39540 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sat, 1 Dec 2018 14:09:32 +0100 Subject: [PATCH 02/20] freeze Go build dependencies in Gopkg.lock * use pseudo-depdencies in build/build.go to convince dep * update Travis, Dockerfile and Docs * build.Dockerfile image now contains the Go build dependencies * => faster builds * bump pdu file after protoc update fixes #106 --- .travis.yml | 4 +- Gopkg.lock | 39 ++++++++++- Gopkg.toml | 8 +++ build.Dockerfile | 13 +++- build/build.go | 22 ++++++ docs/installation.rst | 6 +- lazy.sh | 29 ++++---- replication/pdu/pdu.pb.go | 144 ++++++++++++++++---------------------- 8 files changed, 162 insertions(+), 103 deletions(-) create mode 100644 build/build.go diff --git a/.travis.yml b/.travis.yml index b670a05..8aafc51 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,7 +24,7 @@ matrix: - wget 
https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip - echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c - sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip - - ./lazy.sh builddep + - ./lazy.sh godep - make vendordeps script: - make @@ -43,7 +43,7 @@ matrix: - wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip - echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c - sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip - - ./lazy.sh builddep + - ./lazy.sh godep - make vendordeps script: - make diff --git a/Gopkg.lock b/Gopkg.lock index e3785d6..6ad511b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -1,6 +1,14 @@ # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. +[[projects]] + branch = "master" + digest = "1:8cf2cf1ab10480b5e0df950dac1517aaabde05d055d9d955652997ae4b9ecbbf" + name = "github.com/alvaroloes/enumer" + packages = ["."] + pruneopts = "" + revision = "6bcfe2edaac32ad71b88ce4cf92d34cd643e4ecb" + [[projects]] branch = "master" digest = "1:c0bec5f9b98d0bc872ff5e834fac186b807b656683bd29cb82fb207a1513fabb" @@ -64,7 +72,15 @@ [[projects]] digest = "1:3dd078fda7500c341bc26cfbc6c6a34614f295a2457149fc1045cab767cbcf18" name = "github.com/golang/protobuf" - packages = ["proto"] + packages = [ + "proto", + "protoc-gen-go", + "protoc-gen-go/descriptor", + "protoc-gen-go/generator", + "protoc-gen-go/generator/internal/remap", + "protoc-gen-go/grpc", + "protoc-gen-go/plugin", + ] pruneopts = "" revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" version = "v1.2.0" @@ -141,6 +157,14 @@ revision = "3247c84500bff8d9fb6d579d800f20b3e091582c" version = "v1.0.0" +[[projects]] + branch = "master" + digest = "1:f60ff065b58bd53e641112b38bbda9d2684deb828393c7ffb89c69a1ee301d17" + name = "github.com/pascaldekloe/name" + 
packages = ["."] + pruneopts = "" + revision = "0fd16699aae1833640fca52a937944c6f3b1d58c" + [[projects]] digest = "1:7365acd48986e205ccb8652cc746f09c8b7876030d53710ea6ef7d0bd0dcd7ca" name = "github.com/pkg/errors" @@ -159,7 +183,7 @@ [[projects]] branch = "master" - digest = "1:1392748e290ca66ac8447ef24961f8ae9e1d846a53af0f58a5a0256982ce0577" + digest = "1:25559b520313b941b1395cd5d5ee66086b27dc15a1391c0f2aad29d5c2321f4b" name = "github.com/problame/go-netssh" packages = ["."] pruneopts = "" @@ -285,14 +309,24 @@ revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" version = "v0.3.0" +[[projects]] + branch = "master" + digest = "1:4cd780b2ee42c8eac9c02bfb6e6b52dcbaef770774458c8938f5cbfb73a7b6d3" + name = "golang.org/x/tools" + packages = ["cmd/stringer"] + pruneopts = "" + revision = "d0ca3933b724e6be513276cc2edb34e10d667438" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 input-imports = [ + "github.com/alvaroloes/enumer", "github.com/fatih/color", "github.com/gdamore/tcell/termbox", "github.com/go-logfmt/logfmt", "github.com/golang/protobuf/proto", + "github.com/golang/protobuf/protoc-gen-go", "github.com/jinzhu/copier", "github.com/kr/pretty", "github.com/mattn/go-isatty", @@ -308,6 +342,7 @@ "github.com/stretchr/testify/require", "github.com/zrepl/yaml-config", "golang.org/x/sys/unix", + "golang.org/x/tools/cmd/stringer", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 6c8a9b7..59ddabb 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -73,3 +73,11 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] [[constraint]] name = "github.com/gdamore/tcell" version = "1.0.0" + +[[constraint]] + branch = "master" + name = "golang.org/x/tools" + +[[constraint]] + branch = "master" + name = "github.com/alvaroloes/enumer" diff --git a/build.Dockerfile b/build.Dockerfile index d7729e5..95dcd66 100644 --- a/build.Dockerfile +++ b/build.Dockerfile @@ -11,12 +11,21 @@ RUN unzip -d /usr protoc-3.6.1-linux-x86_64.zip ADD lazy.sh 
/tmp/lazy.sh ADD docs/requirements.txt /tmp/requirements.txt ENV ZREPL_LAZY_DOCS_REQPATH=/tmp/requirements.txt -RUN /tmp/lazy.sh devsetup +RUN /tmp/lazy.sh docdep # prepare volume mount of git checkout to /zrepl RUN mkdir -p /go/src/github.com/zrepl/zrepl -RUN chmod -R 0777 /go RUN mkdir -p /.cache && chmod -R 0777 /.cache WORKDIR /go/src/github.com/zrepl/zrepl +ADD Gopkg.toml Gopkg.lock ./ + +# godep will install the Go dependencies to vendor in order to then build and install +# build dependencies like stringer to $GOPATH/bin. +# However, since users volume-mount their Git checkout into /go/src/github.com/zrepl/zrepl +# the vendor directory will be empty at build time, allowing them to experiment with +# new checkouts, etc. +# Thus, we only use the vendored deps for building dependencies. +RUN /tmp/lazy.sh godep +RUN chmod -R 0777 /go diff --git a/build/build.go b/build/build.go new file mode 100644 index 0000000..ec978e4 --- /dev/null +++ b/build/build.go @@ -0,0 +1,22 @@ +//+build windows + +// This package cannot actually be built, and since the Travis +// tests run go test ./..., we have to avoid a build attempt. +// Windows is not supported atm, so this works ¯\_(ツ)_/¯ + +// This package is a pseudo-user of build-time dependencies for zrepl, +// mostly for various code-generation tools. 
+// +// The imports are necessary to satisfy go dep +package main + +import ( + _ "fmt" + _ "github.com/alvaroloes/enumer" + _ "github.com/golang/protobuf/protoc-gen-go" + _ "golang.org/x/tools/cmd/stringer" +) + +func main() { + fmt.Println("just a placeholder") +} diff --git a/docs/installation.rst b/docs/installation.rst index 42800f8..a9086ca 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -66,10 +66,14 @@ Alternatively, you can install build dependencies on your local system and then mkdir -p "${GOPATH}/src/github.com/zrepl/zrepl" git clone https://github.com/zrepl/zrepl.git "${GOPATH}/src/github.com/zrepl/zrepl" cd "${GOPATH}/src/github.com/zrepl/zrepl" + python3 -m venv3 + source venv3/bin/activate ./lazy.sh devsetup make vendordeps release -Build results are located in the ``artifacts/`` directory. +The Python venv is used for the documentation build dependencies. +If you just want to build the zrepl binary, leave it out and use `./lazy.sh godep` instead. +Either way, all build results are located in the ``artifacts/`` directory. .. NOTE:: diff --git a/lazy.sh b/lazy.sh index 4b87ea8..9b0f246 100755 --- a/lazy.sh +++ b/lazy.sh @@ -25,23 +25,26 @@ fi CHECKOUTPATH="${GOPATH}/src/github.com/zrepl/zrepl" -builddep() { - step "Install build depdencies using 'go get' to \$GOPATH/bin" - go get -u golang.org/x/tools/cmd/stringer +godep() { + step "Install go dep using 'go get' to \$GOPATH/bin" go get -u github.com/golang/dep/cmd/dep - go get -u github.com/golang/protobuf/protoc-gen-go - go get -u github.com/alvaroloes/enumer - if ! type stringer || ! type dep || ! type protoc-gen-go || ! type enumer ; then + if ! 
type dep ; then + echo "Unable to install go dep" 1>&2 + exit 1 + fi + step "Fetching dependencies using 'dep ensure'" + dep ensure -v -vendor-only + step "go install build dependencies fetched using dep" + # these will be in the vendor directory + go build -o "$GOPATH/bin/stringer" ./vendor/golang.org/x/tools/cmd/stringer + go build -o "$GOPATH/bin/protoc-gen-go" ./vendor/github.com/golang/protobuf/protoc-gen-go + go build -o "$GOPATH/bin/enumer" ./vendor/github.com/alvaroloes/enumer + if ! type stringer || ! type protoc-gen-go || ! type enumer ; then echo "Installed dependencies but can't find them in \$PATH, adjust it to contain \$GOPATH/bin" 1>&2 exit 1 fi } -godep() { - step "Fetching dependencies using 'dep ensure'" - dep ensure -} - docdep() { if ! type pip3; then step "pip3 binary not installed or not in \$PATH" 1>&2 @@ -62,13 +65,13 @@ release() { for cmd in "$@"; do case "$cmd" in - builddep|godep|docdep|release_bins|docs) + godep|docdep|release_bins|docs) eval $cmd continue ;; devsetup) step "Installing development dependencies" - builddep + godep docdep step "Development dependencies installed" continue diff --git a/replication/pdu/pdu.pb.go b/replication/pdu/pdu.pb.go index 54470e4..f2b7d37 100644 --- a/replication/pdu/pdu.pb.go +++ b/replication/pdu/pdu.pb.go @@ -3,11 +3,9 @@ package pdu -import ( - fmt "fmt" - proto "github.com/golang/protobuf/proto" - math "math" -) +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" // Reference imports to suppress errors if they are not otherwise used. 
var _ = proto.Marshal @@ -31,7 +29,6 @@ var FilesystemVersion_VersionType_name = map[int32]string{ 0: "Snapshot", 1: "Bookmark", } - var FilesystemVersion_VersionType_value = map[string]int32{ "Snapshot": 0, "Bookmark": 1, @@ -40,9 +37,8 @@ var FilesystemVersion_VersionType_value = map[string]int32{ func (x FilesystemVersion_VersionType) String() string { return proto.EnumName(FilesystemVersion_VersionType_name, int32(x)) } - func (FilesystemVersion_VersionType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{5, 0} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5, 0} } type ListFilesystemReq struct { @@ -55,17 +51,16 @@ func (m *ListFilesystemReq) Reset() { *m = ListFilesystemReq{} } func (m *ListFilesystemReq) String() string { return proto.CompactTextString(m) } func (*ListFilesystemReq) ProtoMessage() {} func (*ListFilesystemReq) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{0} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{0} } - func (m *ListFilesystemReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemReq.Unmarshal(m, b) } func (m *ListFilesystemReq) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ListFilesystemReq.Marshal(b, m, deterministic) } -func (m *ListFilesystemReq) XXX_Merge(src proto.Message) { - xxx_messageInfo_ListFilesystemReq.Merge(m, src) +func (dst *ListFilesystemReq) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListFilesystemReq.Merge(dst, src) } func (m *ListFilesystemReq) XXX_Size() int { return xxx_messageInfo_ListFilesystemReq.Size(m) @@ -87,17 +82,16 @@ func (m *ListFilesystemRes) Reset() { *m = ListFilesystemRes{} } func (m *ListFilesystemRes) String() string { return proto.CompactTextString(m) } func (*ListFilesystemRes) ProtoMessage() {} func (*ListFilesystemRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{1} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{1} 
} - func (m *ListFilesystemRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemRes.Unmarshal(m, b) } func (m *ListFilesystemRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ListFilesystemRes.Marshal(b, m, deterministic) } -func (m *ListFilesystemRes) XXX_Merge(src proto.Message) { - xxx_messageInfo_ListFilesystemRes.Merge(m, src) +func (dst *ListFilesystemRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListFilesystemRes.Merge(dst, src) } func (m *ListFilesystemRes) XXX_Size() int { return xxx_messageInfo_ListFilesystemRes.Size(m) @@ -127,17 +121,16 @@ func (m *Filesystem) Reset() { *m = Filesystem{} } func (m *Filesystem) String() string { return proto.CompactTextString(m) } func (*Filesystem) ProtoMessage() {} func (*Filesystem) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{2} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{2} } - func (m *Filesystem) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Filesystem.Unmarshal(m, b) } func (m *Filesystem) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_Filesystem.Marshal(b, m, deterministic) } -func (m *Filesystem) XXX_Merge(src proto.Message) { - xxx_messageInfo_Filesystem.Merge(m, src) +func (dst *Filesystem) XXX_Merge(src proto.Message) { + xxx_messageInfo_Filesystem.Merge(dst, src) } func (m *Filesystem) XXX_Size() int { return xxx_messageInfo_Filesystem.Size(m) @@ -173,17 +166,16 @@ func (m *ListFilesystemVersionsReq) Reset() { *m = ListFilesystemVersion func (m *ListFilesystemVersionsReq) String() string { return proto.CompactTextString(m) } func (*ListFilesystemVersionsReq) ProtoMessage() {} func (*ListFilesystemVersionsReq) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{3} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{3} } - func (m *ListFilesystemVersionsReq) XXX_Unmarshal(b []byte) error { return 
xxx_messageInfo_ListFilesystemVersionsReq.Unmarshal(m, b) } func (m *ListFilesystemVersionsReq) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ListFilesystemVersionsReq.Marshal(b, m, deterministic) } -func (m *ListFilesystemVersionsReq) XXX_Merge(src proto.Message) { - xxx_messageInfo_ListFilesystemVersionsReq.Merge(m, src) +func (dst *ListFilesystemVersionsReq) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListFilesystemVersionsReq.Merge(dst, src) } func (m *ListFilesystemVersionsReq) XXX_Size() int { return xxx_messageInfo_ListFilesystemVersionsReq.Size(m) @@ -212,17 +204,16 @@ func (m *ListFilesystemVersionsRes) Reset() { *m = ListFilesystemVersion func (m *ListFilesystemVersionsRes) String() string { return proto.CompactTextString(m) } func (*ListFilesystemVersionsRes) ProtoMessage() {} func (*ListFilesystemVersionsRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{4} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{4} } - func (m *ListFilesystemVersionsRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemVersionsRes.Unmarshal(m, b) } func (m *ListFilesystemVersionsRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ListFilesystemVersionsRes.Marshal(b, m, deterministic) } -func (m *ListFilesystemVersionsRes) XXX_Merge(src proto.Message) { - xxx_messageInfo_ListFilesystemVersionsRes.Merge(m, src) +func (dst *ListFilesystemVersionsRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_ListFilesystemVersionsRes.Merge(dst, src) } func (m *ListFilesystemVersionsRes) XXX_Size() int { return xxx_messageInfo_ListFilesystemVersionsRes.Size(m) @@ -255,17 +246,16 @@ func (m *FilesystemVersion) Reset() { *m = FilesystemVersion{} } func (m *FilesystemVersion) String() string { return proto.CompactTextString(m) } func (*FilesystemVersion) ProtoMessage() {} func (*FilesystemVersion) Descriptor() ([]byte, []int) { - return 
fileDescriptor_5e683fe3d6db3968, []int{5} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5} } - func (m *FilesystemVersion) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_FilesystemVersion.Unmarshal(m, b) } func (m *FilesystemVersion) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_FilesystemVersion.Marshal(b, m, deterministic) } -func (m *FilesystemVersion) XXX_Merge(src proto.Message) { - xxx_messageInfo_FilesystemVersion.Merge(m, src) +func (dst *FilesystemVersion) XXX_Merge(src proto.Message) { + xxx_messageInfo_FilesystemVersion.Merge(dst, src) } func (m *FilesystemVersion) XXX_Size() int { return xxx_messageInfo_FilesystemVersion.Size(m) @@ -336,17 +326,16 @@ func (m *SendReq) Reset() { *m = SendReq{} } func (m *SendReq) String() string { return proto.CompactTextString(m) } func (*SendReq) ProtoMessage() {} func (*SendReq) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{6} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{6} } - func (m *SendReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_SendReq.Unmarshal(m, b) } func (m *SendReq) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_SendReq.Marshal(b, m, deterministic) } -func (m *SendReq) XXX_Merge(src proto.Message) { - xxx_messageInfo_SendReq.Merge(m, src) +func (dst *SendReq) XXX_Merge(src proto.Message) { + xxx_messageInfo_SendReq.Merge(dst, src) } func (m *SendReq) XXX_Size() int { return xxx_messageInfo_SendReq.Size(m) @@ -418,17 +407,16 @@ func (m *Property) Reset() { *m = Property{} } func (m *Property) String() string { return proto.CompactTextString(m) } func (*Property) ProtoMessage() {} func (*Property) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{7} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{7} } - func (m *Property) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Property.Unmarshal(m, b) } func (m *Property) XXX_Marshal(b 
[]byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_Property.Marshal(b, m, deterministic) } -func (m *Property) XXX_Merge(src proto.Message) { - xxx_messageInfo_Property.Merge(m, src) +func (dst *Property) XXX_Merge(src proto.Message) { + xxx_messageInfo_Property.Merge(dst, src) } func (m *Property) XXX_Size() int { return xxx_messageInfo_Property.Size(m) @@ -469,17 +457,16 @@ func (m *SendRes) Reset() { *m = SendRes{} } func (m *SendRes) String() string { return proto.CompactTextString(m) } func (*SendRes) ProtoMessage() {} func (*SendRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{8} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{8} } - func (m *SendRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_SendRes.Unmarshal(m, b) } func (m *SendRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_SendRes.Marshal(b, m, deterministic) } -func (m *SendRes) XXX_Merge(src proto.Message) { - xxx_messageInfo_SendRes.Merge(m, src) +func (dst *SendRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_SendRes.Merge(dst, src) } func (m *SendRes) XXX_Size() int { return xxx_messageInfo_SendRes.Size(m) @@ -524,17 +511,16 @@ func (m *ReceiveReq) Reset() { *m = ReceiveReq{} } func (m *ReceiveReq) String() string { return proto.CompactTextString(m) } func (*ReceiveReq) ProtoMessage() {} func (*ReceiveReq) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{9} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{9} } - func (m *ReceiveReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReceiveReq.Unmarshal(m, b) } func (m *ReceiveReq) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ReceiveReq.Marshal(b, m, deterministic) } -func (m *ReceiveReq) XXX_Merge(src proto.Message) { - xxx_messageInfo_ReceiveReq.Merge(m, src) +func (dst *ReceiveReq) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReceiveReq.Merge(dst, src) } 
func (m *ReceiveReq) XXX_Size() int { return xxx_messageInfo_ReceiveReq.Size(m) @@ -569,17 +555,16 @@ func (m *ReceiveRes) Reset() { *m = ReceiveRes{} } func (m *ReceiveRes) String() string { return proto.CompactTextString(m) } func (*ReceiveRes) ProtoMessage() {} func (*ReceiveRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{10} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{10} } - func (m *ReceiveRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReceiveRes.Unmarshal(m, b) } func (m *ReceiveRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ReceiveRes.Marshal(b, m, deterministic) } -func (m *ReceiveRes) XXX_Merge(src proto.Message) { - xxx_messageInfo_ReceiveRes.Merge(m, src) +func (dst *ReceiveRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReceiveRes.Merge(dst, src) } func (m *ReceiveRes) XXX_Size() int { return xxx_messageInfo_ReceiveRes.Size(m) @@ -603,17 +588,16 @@ func (m *DestroySnapshotsReq) Reset() { *m = DestroySnapshotsReq{} } func (m *DestroySnapshotsReq) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotsReq) ProtoMessage() {} func (*DestroySnapshotsReq) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{11} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{11} } - func (m *DestroySnapshotsReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotsReq.Unmarshal(m, b) } func (m *DestroySnapshotsReq) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_DestroySnapshotsReq.Marshal(b, m, deterministic) } -func (m *DestroySnapshotsReq) XXX_Merge(src proto.Message) { - xxx_messageInfo_DestroySnapshotsReq.Merge(m, src) +func (dst *DestroySnapshotsReq) XXX_Merge(src proto.Message) { + xxx_messageInfo_DestroySnapshotsReq.Merge(dst, src) } func (m *DestroySnapshotsReq) XXX_Size() int { return xxx_messageInfo_DestroySnapshotsReq.Size(m) @@ -650,17 +634,16 @@ func (m 
*DestroySnapshotRes) Reset() { *m = DestroySnapshotRes{} } func (m *DestroySnapshotRes) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotRes) ProtoMessage() {} func (*DestroySnapshotRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{12} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{12} } - func (m *DestroySnapshotRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotRes.Unmarshal(m, b) } func (m *DestroySnapshotRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_DestroySnapshotRes.Marshal(b, m, deterministic) } -func (m *DestroySnapshotRes) XXX_Merge(src proto.Message) { - xxx_messageInfo_DestroySnapshotRes.Merge(m, src) +func (dst *DestroySnapshotRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_DestroySnapshotRes.Merge(dst, src) } func (m *DestroySnapshotRes) XXX_Size() int { return xxx_messageInfo_DestroySnapshotRes.Size(m) @@ -696,17 +679,16 @@ func (m *DestroySnapshotsRes) Reset() { *m = DestroySnapshotsRes{} } func (m *DestroySnapshotsRes) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotsRes) ProtoMessage() {} func (*DestroySnapshotsRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{13} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{13} } - func (m *DestroySnapshotsRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotsRes.Unmarshal(m, b) } func (m *DestroySnapshotsRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_DestroySnapshotsRes.Marshal(b, m, deterministic) } -func (m *DestroySnapshotsRes) XXX_Merge(src proto.Message) { - xxx_messageInfo_DestroySnapshotsRes.Merge(m, src) +func (dst *DestroySnapshotsRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_DestroySnapshotsRes.Merge(dst, src) } func (m *DestroySnapshotsRes) XXX_Size() int { return xxx_messageInfo_DestroySnapshotsRes.Size(m) @@ -739,17 +721,16 @@ 
func (m *ReplicationCursorReq) Reset() { *m = ReplicationCursorReq{} } func (m *ReplicationCursorReq) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq) ProtoMessage() {} func (*ReplicationCursorReq) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{14} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14} } - func (m *ReplicationCursorReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq.Unmarshal(m, b) } func (m *ReplicationCursorReq) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ReplicationCursorReq.Marshal(b, m, deterministic) } -func (m *ReplicationCursorReq) XXX_Merge(src proto.Message) { - xxx_messageInfo_ReplicationCursorReq.Merge(m, src) +func (dst *ReplicationCursorReq) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReplicationCursorReq.Merge(dst, src) } func (m *ReplicationCursorReq) XXX_Size() int { return xxx_messageInfo_ReplicationCursorReq.Size(m) @@ -888,17 +869,16 @@ func (m *ReplicationCursorReq_GetOp) Reset() { *m = ReplicationCursorReq func (m *ReplicationCursorReq_GetOp) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq_GetOp) ProtoMessage() {} func (*ReplicationCursorReq_GetOp) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{14, 0} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 0} } - func (m *ReplicationCursorReq_GetOp) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq_GetOp.Unmarshal(m, b) } func (m *ReplicationCursorReq_GetOp) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ReplicationCursorReq_GetOp.Marshal(b, m, deterministic) } -func (m *ReplicationCursorReq_GetOp) XXX_Merge(src proto.Message) { - xxx_messageInfo_ReplicationCursorReq_GetOp.Merge(m, src) +func (dst *ReplicationCursorReq_GetOp) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReplicationCursorReq_GetOp.Merge(dst, 
src) } func (m *ReplicationCursorReq_GetOp) XXX_Size() int { return xxx_messageInfo_ReplicationCursorReq_GetOp.Size(m) @@ -920,17 +900,16 @@ func (m *ReplicationCursorReq_SetOp) Reset() { *m = ReplicationCursorReq func (m *ReplicationCursorReq_SetOp) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq_SetOp) ProtoMessage() {} func (*ReplicationCursorReq_SetOp) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{14, 1} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 1} } - func (m *ReplicationCursorReq_SetOp) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq_SetOp.Unmarshal(m, b) } func (m *ReplicationCursorReq_SetOp) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ReplicationCursorReq_SetOp.Marshal(b, m, deterministic) } -func (m *ReplicationCursorReq_SetOp) XXX_Merge(src proto.Message) { - xxx_messageInfo_ReplicationCursorReq_SetOp.Merge(m, src) +func (dst *ReplicationCursorReq_SetOp) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReplicationCursorReq_SetOp.Merge(dst, src) } func (m *ReplicationCursorReq_SetOp) XXX_Size() int { return xxx_messageInfo_ReplicationCursorReq_SetOp.Size(m) @@ -962,17 +941,16 @@ func (m *ReplicationCursorRes) Reset() { *m = ReplicationCursorRes{} } func (m *ReplicationCursorRes) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorRes) ProtoMessage() {} func (*ReplicationCursorRes) Descriptor() ([]byte, []int) { - return fileDescriptor_5e683fe3d6db3968, []int{15} + return fileDescriptor_pdu_fe566e6b212fcf8d, []int{15} } - func (m *ReplicationCursorRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorRes.Unmarshal(m, b) } func (m *ReplicationCursorRes) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_ReplicationCursorRes.Marshal(b, m, deterministic) } -func (m *ReplicationCursorRes) XXX_Merge(src proto.Message) { - 
xxx_messageInfo_ReplicationCursorRes.Merge(m, src) +func (dst *ReplicationCursorRes) XXX_Merge(src proto.Message) { + xxx_messageInfo_ReplicationCursorRes.Merge(dst, src) } func (m *ReplicationCursorRes) XXX_Size() int { return xxx_messageInfo_ReplicationCursorRes.Size(m) @@ -1089,7 +1067,6 @@ func _ReplicationCursorRes_OneofSizer(msg proto.Message) (n int) { } func init() { - proto.RegisterEnum("pdu.FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) proto.RegisterType((*ListFilesystemReq)(nil), "pdu.ListFilesystemReq") proto.RegisterType((*ListFilesystemRes)(nil), "pdu.ListFilesystemRes") proto.RegisterType((*Filesystem)(nil), "pdu.Filesystem") @@ -1108,11 +1085,12 @@ func init() { proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "pdu.ReplicationCursorReq.GetOp") proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "pdu.ReplicationCursorReq.SetOp") proto.RegisterType((*ReplicationCursorRes)(nil), "pdu.ReplicationCursorRes") + proto.RegisterEnum("pdu.FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) } -func init() { proto.RegisterFile("pdu.proto", fileDescriptor_5e683fe3d6db3968) } +func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_fe566e6b212fcf8d) } -var fileDescriptor_5e683fe3d6db3968 = []byte{ +var fileDescriptor_pdu_fe566e6b212fcf8d = []byte{ // 659 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdb, 0x6e, 0x13, 0x31, 0x10, 0xcd, 0xe6, 0xba, 0x99, 0x94, 0x5e, 0xdc, 0xaa, 0x2c, 0x15, 0x82, 0xc8, 0xbc, 0x04, 0x24, From 1aae7b222f0600010f8472a617e294e174985211 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sat, 1 Dec 2018 14:55:22 +0100 Subject: [PATCH 03/20] docs: fix confusing description of the role of client identity for sink jobs --- docs/configuration/jobs.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/jobs.rst 
b/docs/configuration/jobs.rst index a7bdcae..077784f 100644 --- a/docs/configuration/jobs.rst +++ b/docs/configuration/jobs.rst @@ -57,7 +57,7 @@ using the transport listener type specified in the ``serve`` field of the job co Each transport listener provides a client's identity to the passive side job. It uses the client identity for access control: -* The ``sink`` job only allows pushes to those ZFS filesystems to the active side that are located below ``root_fs/${client_identity}``. +* The ``sink`` job maps requests from different client identities to their respective sub-filesystem tree ``root_fs/${client_identity}``. * The ``source`` job has a whitelist of client identities that are allowed pull access. .. TIP:: From 7a75a4d384ef9f39606c47e975d65e44f259901a Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 6 Nov 2018 23:37:25 +0100 Subject: [PATCH 04/20] util/iocommand: timeout kill on close + other hardening --- util/iocommand.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/util/iocommand.go b/util/iocommand.go index 1113f74..fe5f9a3 100644 --- a/util/iocommand.go +++ b/util/iocommand.go @@ -4,15 +4,18 @@ import ( "bytes" "context" "fmt" + "github.com/zrepl/zrepl/util/envconst" "io" "os" "os/exec" "syscall" + "time" ) // An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface. type IOCommand struct { Cmd *exec.Cmd + kill context.CancelFunc Stdin io.WriteCloser Stdout io.ReadCloser StderrBuf *bytes.Buffer @@ -52,6 +55,7 @@ func NewIOCommand(ctx context.Context, command string, args []string, stderrBufS c = &IOCommand{} + ctx, c.kill = context.WithCancel(ctx) c.Cmd = exec.CommandContext(ctx, command, args...) 
if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil { @@ -81,14 +85,24 @@ func (c *IOCommand) Start() (err error) { func (c *IOCommand) Read(buf []byte) (n int, err error) { n, err = c.Stdout.Read(buf) if err == io.EOF { - if waitErr := c.doWait(); waitErr != nil { + if waitErr := c.doWait(context.Background()); waitErr != nil { err = waitErr } } return } -func (c *IOCommand) doWait() (err error) { +func (c *IOCommand) doWait(ctx context.Context) (err error) { + go func() { + dl, ok := ctx.Deadline() + if !ok { + return + } + time.Sleep(dl.Sub(time.Now())) + c.kill() + c.Stdout.Close() + c.Stdin.Close() + }() waitErr := c.Cmd.Wait() var wasUs bool = false var waitStatus syscall.WaitStatus @@ -133,10 +147,9 @@ func (c *IOCommand) Close() (err error) { if c.Cmd.ProcessState == nil { // racy... err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM) - if err != nil { - return - } - return c.doWait() + ctx, cancel := context.WithTimeout(context.Background(), envconst.Duration("IOCOMMAND_TIMEOUT", 10*time.Second)) + defer cancel() + return c.doWait(ctx) } else { return c.ExitResult.Error } From 68b62a5c008be8d701c2c9028e0e02c452ea1be8 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 11 Dec 2018 21:19:48 +0100 Subject: [PATCH 05/20] tlsconf: clear handshake deadline after completed handshake --- tlsconf/tlsconf.go | 1 + 1 file changed, 1 insertion(+) diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go index 48fc382..6547518 100644 --- a/tlsconf/tlsconf.go +++ b/tlsconf/tlsconf.go @@ -70,6 +70,7 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { if err = tlsConn.Handshake(); err != nil { goto CloseAndErr } + tlsConn.SetDeadline(time.Time{}) peerCerts = tlsConn.ConnectionState().PeerCertificates if len(peerCerts) != 1 { From ef3283638a7c46a0a3f752d62bb19ff0dd4df8b8 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 11 Dec 2018 21:19:57 +0100 Subject: [PATCH 06/20] logger: add stderrlogger (sometimes useful) --- 
logger/stderrlogger.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 logger/stderrlogger.go diff --git a/logger/stderrlogger.go b/logger/stderrlogger.go new file mode 100644 index 0000000..6ca0aad --- /dev/null +++ b/logger/stderrlogger.go @@ -0,0 +1,27 @@ +package logger + +import ( + "fmt" + "os" +) + +type stderrLogger struct { + Logger +} + +type stderrLoggerOutlet struct {} + +func (stderrLoggerOutlet) WriteEntry(entry Entry) error { + fmt.Fprintf(os.Stderr, "%#v\n", entry) + return nil +} + +var _ Logger = testLogger{} + +func NewStderrDebugLogger() Logger { + outlets := NewOutlets() + outlets.Add(&stderrLoggerOutlet{}, Debug) + return &testLogger{ + Logger: NewLogger(outlets, 0), + } +} From c1aab0bee92e9d2a623554a62ad8cbdb2aad4ba9 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 11 Dec 2018 21:54:36 +0100 Subject: [PATCH 07/20] config: update yaml-config and use zeropositive constraint for timeouts --- Gopkg.lock | 4 ++-- config/config.go | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 6ad511b..a11b64a 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -281,11 +281,11 @@ [[projects]] branch = "v2" - digest = "1:9d92186f609a73744232323416ddafd56fae67cb552162cc190ab903e36900dd" + digest = "1:6b8a6afafde7ed31cd0c577ba40d88ce39e8f1c5eb76d7836be7d5b74f1c534a" name = "github.com/zrepl/yaml-config" packages = ["."] pruneopts = "" - revision = "af27d27978ad95808723a62d87557d63c3ff0605" + revision = "08227ad854131f7dfcdfb12579fb73dd8a38a03a" [[projects]] branch = "master" diff --git a/config/config.go b/config/config.go index eeea023..effa8f8 100644 --- a/config/config.go +++ b/config/config.go @@ -167,7 +167,7 @@ type ConnectCommon struct { type TCPConnect struct { ConnectCommon `yaml:",inline"` Address string `yaml:"address"` - DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` + DialTimeout time.Duration 
`yaml:"dial_timeout,zeropositive,default=10s"` } type TLSConnect struct { @@ -177,7 +177,7 @@ type TLSConnect struct { Cert string `yaml:"cert"` Key string `yaml:"key"` ServerCN string `yaml:"server_cn"` - DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` + DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"` } type SSHStdinserverConnect struct { @@ -189,7 +189,7 @@ type SSHStdinserverConnect struct { TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused SSHCommand string `yaml:"ssh_command,optional"` //TODO unused Options []string `yaml:"options,optional"` - DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` + DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"` } type LocalConnect struct { @@ -220,7 +220,7 @@ type TLSServe struct { Cert string `yaml:"cert"` Key string `yaml:"key"` ClientCNs []string `yaml:"client_cns"` - HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"` + HandshakeTimeout time.Duration `yaml:"handshake_timeout,zeropositive,default=10s"` } type StdinserverServer struct { From 38b0bd76f5115891efc7fde2d6afe5c9455dbc9e Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 11 Dec 2018 22:00:03 +0100 Subject: [PATCH 08/20] build: just use go {test,vet} ./... for targets vet, test and generate --- .travis.yml | 2 -- Makefile | 55 +++-------------------------------------------------- 2 files changed, 3 insertions(+), 54 deletions(-) diff --git a/.travis.yml b/.travis.yml index 8aafc51..2321376 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,7 +30,6 @@ matrix: - make - make vet - make test - - go test ./... - make artifacts/zrepl-freebsd-amd64 - make artifacts/zrepl-linux-amd64 - make artifacts/zrepl-darwin-amd64 @@ -49,7 +48,6 @@ matrix: - make - make vet - make test - - go test ./... 
- make artifacts/zrepl-freebsd-amd64 - make artifacts/zrepl-linux-amd64 - make artifacts/zrepl-darwin-amd64 diff --git a/Makefile b/Makefile index a0f31a6..6036caa 100644 --- a/Makefile +++ b/Makefile @@ -1,38 +1,6 @@ .PHONY: generate build test vet cover release docs docs-clean clean vendordeps .DEFAULT_GOAL := build -ROOT := github.com/zrepl/zrepl -SUBPKGS += client -SUBPKGS += config -SUBPKGS += daemon -SUBPKGS += daemon/filters -SUBPKGS += daemon/job -SUBPKGS += daemon/logging -SUBPKGS += daemon/nethelpers -SUBPKGS += daemon/pruner -SUBPKGS += daemon/snapper -SUBPKGS += daemon/streamrpcconfig -SUBPKGS += daemon/transport -SUBPKGS += daemon/transport/connecter -SUBPKGS += daemon/transport/serve -SUBPKGS += endpoint -SUBPKGS += logger -SUBPKGS += pruning -SUBPKGS += pruning/retentiongrid -SUBPKGS += replication -SUBPKGS += replication/fsrep -SUBPKGS += replication/pdu -SUBPKGS += replication/internal/diff -SUBPKGS += tlsconf -SUBPKGS += util -SUBPKGS += util/socketpair -SUBPKGS += util/watchdog -SUBPKGS += util/envconst -SUBPKGS += version -SUBPKGS += zfs - -_TESTPKGS := $(ROOT) $(foreach p,$(SUBPKGS),$(ROOT)/$(p)) - ARTIFACTDIR := artifacts ifdef ZREPL_VERSION @@ -60,34 +28,17 @@ vendordeps: generate: #not part of the build, must do that manually protoc -I=replication/pdu --go_out=replication/pdu replication/pdu/pdu.proto - @for pkg in $(_TESTPKGS); do\ - go generate "$$pkg" || exit 1; \ - done; + go generate -x ./... build: @echo "INFO: In case of missing dependencies, run 'make vendordeps'" $(GO_BUILD) -o "$(ARTIFACTDIR)/zrepl" test: - @for pkg in $(_TESTPKGS); do \ - echo "Testing $$pkg"; \ - go test "$$pkg" || exit 1; \ - done; + go test ./... 
vet: - @for pkg in $(_TESTPKGS); do \ - echo "Vetting $$pkg"; \ - go vet "$$pkg" || exit 1; \ - done; - -cover: artifacts - @for pkg in $(_TESTPKGS); do \ - profile="$(ARTIFACTDIR)/cover-$$(basename $$pkg).out"; \ - go test -coverprofile "$$profile" $$pkg || exit 1; \ - if [ -f "$$profile" ]; then \ - go tool cover -html="$$profile" -o "$${profile}.html" || exit 2; \ - fi; \ - done; + go vet ./... $(ARTIFACTDIR): mkdir -p "$@" From bb5278fe9bd6e3de295191c2d6e4ace3174d66ce Mon Sep 17 00:00:00 2001 From: Josh Souza Date: Fri, 30 Nov 2018 16:34:29 -0700 Subject: [PATCH 09/20] Permit peers to provide a cert chain (multiple certs). fixes #103 --- tlsconf/tlsconf.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go index 48fc382..38f0734 100644 --- a/tlsconf/tlsconf.go +++ b/tlsconf/tlsconf.go @@ -72,7 +72,7 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { } peerCerts = tlsConn.ConnectionState().PeerCertificates - if len(peerCerts) != 1 { + if len(peerCerts) < 1 { err = errors.New("unexpected number of certificates presented by TLS client") goto CloseAndErr } From f724480c7b25936b7c766108f7913e90e78f7176 Mon Sep 17 00:00:00 2001 From: Josh Souza Date: Tue, 22 Jan 2019 10:09:24 -0800 Subject: [PATCH 10/20] Add documentation regarding using a certificate chain --- docs/configuration/transports.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/configuration/transports.rst b/docs/configuration/transports.rst index a705bc9..abd5347 100644 --- a/docs/configuration/transports.rst +++ b/docs/configuration/transports.rst @@ -77,6 +77,8 @@ Connect The ``tls`` transport uses TCP + TLS with client authentication using client certificates. The client identity is the common name (CN) presented in the client certificate. It is recommended to set up a dedicated CA infrastructure for this transport, e.g. using OpenVPN's `EasyRSA `_. 
+When utilizing a CA infrastructure, provide a full chain certificate with the sender's certificate first in the list, with each following certificate directly certifying the one preceding it, as required by the TLS specification. + For a simple 2-machine setup, see the :ref:`instructions below`. The implementation uses `Go's TLS library `_. From 3105fa4ff8c2a0f70e9fce79179d370cea93734f Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 11 Dec 2018 21:44:26 +0100 Subject: [PATCH 11/20] build: use dep's required feature for dev tools --- Gopkg.toml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/Gopkg.toml b/Gopkg.toml index 59ddabb..c213f42 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,5 +1,10 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] +required = [ + "golang.org/x/tools/cmd/stringer", + "github.com/alvaroloes/enumer", +] + [[constraint]] branch = "master" name = "github.com/ftrvxmtrx/fd" @@ -73,11 +78,3 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] [[constraint]] name = "github.com/gdamore/tcell" version = "1.0.0" - -[[constraint]] - branch = "master" - name = "golang.org/x/tools" - -[[constraint]] - branch = "master" - name = "github.com/alvaroloes/enumer" From ea719f5b5a1a3872d07377b64d6e1b7c40a3a839 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sat, 5 Jan 2019 21:53:59 +0100 Subject: [PATCH 12/20] build: use 'git describe --always' to determine ZREPL_VERSION --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 6036caa..343af9b 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ ifdef ZREPL_VERSION _ZREPL_VERSION := $(ZREPL_VERSION) endif ifndef _ZREPL_VERSION - _ZREPL_VERSION := $(shell git describe --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" ) + _ZREPL_VERSION := $(shell git describe --always --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" ) ifeq ($(_ZREPL_VERSION),ZREPL_BUILD_INVALID_VERSION) # can't use .SHELLSTATUS 
because Debian Stretch is still on gmake 4.1 $(error cannot infer variable ZREPL_VERSION using git and variable is not overriden by make invocation) endif @@ -83,7 +83,7 @@ release: $(RELEASE_BINS) $(RELEASE_NOARCH) cp $^ "$(ARTIFACTDIR)/release" cd "$(ARTIFACTDIR)/release" && sha512sum $$(ls | sort) > sha512sum.txt @# note that we use ZREPL_VERSION and not _ZREPL_VERSION because we want to detect the override - @if git describe --dirty 2>/dev/null | grep dirty >/dev/null; then \ + @if git describe --always --dirty 2>/dev/null | grep dirty >/dev/null; then \ echo '[INFO] either git reports checkout is dirty or git is not installed or this is not a git checkout'; \ if [ "$(ZREPL_VERSION)" = "" ]; then \ echo '[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override'; \ From 25c974f0b57eb09d1e04c1b3dc78356e6bb96a17 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sun, 30 Dec 2018 20:43:51 +0100 Subject: [PATCH 13/20] envconst: support for int64 --- util/envconst/envconst.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/util/envconst/envconst.go b/util/envconst/envconst.go index 8159aae..8c13190 100644 --- a/util/envconst/envconst.go +++ b/util/envconst/envconst.go @@ -2,6 +2,7 @@ package envconst import ( "os" + "strconv" "sync" "time" ) @@ -23,3 +24,19 @@ func Duration(varname string, def time.Duration) time.Duration { cache.Store(varname, d) return d } + +func Int64(varname string, def int64) int64 { + if v, ok := cache.Load(varname); ok { + return v.(int64) + } + e := os.Getenv(varname) + if e == "" { + return def + } + d, err := strconv.ParseInt(e, 10, 64) + if err != nil { + panic(err) + } + cache.Store(varname, d) + return d +} From 76a6c623f3fa16dccc49ae88b3d2b6664d1fb7ac Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Thu, 17 Jan 2019 01:43:39 +0100 Subject: [PATCH 14/20] tlsconf and transport/tls: support NSS-formatted keylog file for debugging ... 
via env variable --- tlsconf/tlsconf.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go index 6547518..ffe6094 100644 --- a/tlsconf/tlsconf.go +++ b/tlsconf/tlsconf.go @@ -4,8 +4,11 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" + "io" "io/ioutil" "net" + "os" "time" ) @@ -42,6 +45,7 @@ func NewClientAuthListener( ClientCAs: ca, ClientAuth: tls.RequireAndVerifyClientCert, PreferServerCipherSuites: true, + KeyLogWriter: keylogFromEnv(), } l = tls.NewListener(l, &tlsConf) return &ClientAuthListener{ @@ -106,7 +110,21 @@ func ClientAuthClient(serverName string, rootCA *x509.CertPool, clientCert tls.C Certificates: []tls.Certificate{clientCert}, RootCAs: rootCA, ServerName: serverName, + KeyLogWriter: keylogFromEnv(), } tlsConfig.BuildNameToCertificate() return tlsConfig, nil } + +func keylogFromEnv() io.Writer { + var keyLog io.Writer = nil + if outfile := os.Getenv("ZREPL_KEYLOG_FILE"); outfile != "" { + fmt.Fprintf(os.Stderr, "writing to key log %s\n", outfile) + var err error + keyLog, err = os.OpenFile(outfile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + panic(err) + } + } + return keyLog +} From d281fb00e3388321557efcf5167babec09bf3203 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sun, 30 Dec 2018 20:22:54 +0100 Subject: [PATCH 15/20] socketpair: directly export *net.UnixConn (and add test for that behavior) --- util/socketpair/socketpair.go | 28 +++++++++------------------- util/socketpair/socketpair_test.go | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 19 deletions(-) create mode 100644 util/socketpair/socketpair_test.go diff --git a/util/socketpair/socketpair.go b/util/socketpair/socketpair.go index 615c8f2..c1da9e3 100644 --- a/util/socketpair/socketpair.go +++ b/util/socketpair/socketpair.go @@ -1,42 +1,32 @@ package socketpair import ( - "golang.org/x/sys/unix" "net" "os" + + "golang.org/x/sys/unix" ) -type fileConn struct { - net.Conn // 
net.FileConn - f *os.File -} -func (c fileConn) Close() error { - if err := c.Conn.Close(); err != nil { - return err - } - if err := c.f.Close(); err != nil { - return err - } - return nil -} - -func SocketPair() (a, b net.Conn, err error) { +func SocketPair() (a, b *net.UnixConn, err error) { // don't use net.Pipe, as it doesn't implement things like lingering, which our code relies on sockpair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) if err != nil { return nil, nil, err } - toConn := func(fd int) (net.Conn, error) { + toConn := func(fd int) (*net.UnixConn, error) { f := os.NewFile(uintptr(fd), "fileconn") if f == nil { panic(fd) } c, err := net.FileConn(f) + f.Close() // net.FileConn uses dup under the hood if err != nil { - f.Close() return nil, err } - return fileConn{Conn: c, f: f}, nil + // strictly, the following type assertion is an implementation detail + // however, it will be caught by the test TestSocketPairWorks + fileConnIsUnixConn := c.(*net.UnixConn) + return fileConnIsUnixConn, nil } if a, err = toConn(sockpair[0]); err != nil { // shadowing return nil, nil, err diff --git a/util/socketpair/socketpair_test.go b/util/socketpair/socketpair_test.go new file mode 100644 index 0000000..95cbdc5 --- /dev/null +++ b/util/socketpair/socketpair_test.go @@ -0,0 +1,18 @@ +package socketpair + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// This test is mostly to verify that the assumption about +// net.FileConn returning *net.UnixConn for AF_UNIX FDs works. 
+func TestSocketPairWorks(t *testing.T) { + assert.NotPanics(t, func() { + a, b, err := SocketPair() + assert.NoError(t, err) + a.Close() + b.Close() + }) +} From 796c5ad42d90d8017e244955b3ea1e4be714ddab Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 11 Dec 2018 22:01:50 +0100 Subject: [PATCH 16/20] rpc rewrite: control RPCs using gRPC + separate RPC for data transfer transport/ssh: update go-netssh to new version => supports CloseWrite and Deadlines => build: require Go 1.11 (netssh requires it) --- .travis.yml | 4 +- Gopkg.lock | 121 ++++- Gopkg.toml | 35 +- Makefile | 7 +- config/config.go | 17 +- config/config_rpc_test.go | 86 ---- config/config_test.go | 24 +- daemon/job/active.go | 134 +++-- daemon/job/passive.go | 128 ++--- daemon/logging/adaptors.go | 32 -- daemon/logging/build_logging.go | 48 +- daemon/prometheus.go | 5 + daemon/pruner/pruner.go | 15 +- daemon/pruner/pruner_test.go | 9 +- daemon/snapper/snapper.go | 8 +- daemon/streamrpcconfig/streamrpcconfig.go | 25 - daemon/transport/connecter/connecter.go | 84 ---- daemon/transport/handshake.go | 136 ------ daemon/transport/serve/serve.go | 147 ------ daemon/transport/serve/serve_tls.go | 83 ---- docs/configuration/transports.rst | 14 +- endpoint/context.go | 1 + endpoint/endpoint.go | 457 +++++------------- replication/fsrep/fsfsm.go | 57 ++- replication/mainfsm.go | 24 +- replication/pdu/pdu.pb.go | 362 ++++++++++---- replication/pdu/pdu.proto | 22 +- rpc/dataconn/base2bufpool/base2bufpool.go | 169 +++++++ .../base2bufpool/base2bufpool_test.go | 98 ++++ .../base2bufpool/nofitbehavior_enumer.go | 51 ++ rpc/dataconn/dataconn_client.go | 215 ++++++++ rpc/dataconn/dataconn_debug.go | 20 + rpc/dataconn/dataconn_server.go | 178 +++++++ rpc/dataconn/dataconn_shared.go | 70 +++ rpc/dataconn/dataconn_test.go | 1 + rpc/dataconn/frameconn/frameconn.go | 346 +++++++++++++ .../frameconn/frameconn_prometheus.go | 63 +++ .../frameconn/frameconn_shutdown_fsm.go | 37 ++ 
rpc/dataconn/frameconn/frameconn_test.go | 22 + rpc/dataconn/heartbeatconn/heartbeatconn.go | 137 ++++++ .../heartbeatconn/heartbeatconn_debug.go | 20 + .../heartbeatconn/heartbeatconn_test.go | 26 + rpc/dataconn/microbenchmark/microbenchmark.go | 135 ++++++ rpc/dataconn/stream/stream.go | 269 +++++++++++ rpc/dataconn/stream/stream_conn.go | 194 ++++++++ rpc/dataconn/stream/stream_debug.go | 20 + rpc/dataconn/stream/stream_test.go | 131 +++++ .../internal/wireevaluator/testbed/.gitignore | 12 + .../internal/wireevaluator/testbed/README.md | 15 + .../internal/wireevaluator/testbed/all.yml | 17 + .../wireevaluator/testbed/gen_files.sh | 55 +++ .../internal_prepare_and_run_repeated.yml | 54 +++ .../internal_run_test_prepared_single.yml | 38 ++ .../wireevaluator/testbed/inventory.example | 2 + .../testbed/templates/ssh.yml.j2 | 13 + .../testbed/templates/tcp.yml.j2 | 10 + .../testbed/templates/tls.yml.j2 | 16 + .../internal/wireevaluator/wireevaluator.go | 111 +++++ .../wireevaluator/wireevaluator_closewrite.go | 110 +++++ .../wireevaluator/wireevaluator_deadlines.go | 138 ++++++ rpc/dataconn/timeoutconn/timeoutconn.go | 288 +++++++++++ rpc/dataconn/timeoutconn/timeoutconn_debug.go | 19 + rpc/dataconn/timeoutconn/timeoutconn_test.go | 177 +++++++ .../authlistener_grpc_adaptor.go | 118 +++++ rpc/grpcclientidentity/example/grpcauth.proto | 16 + rpc/grpcclientidentity/example/main.go | 107 ++++ .../example/pdu/grpcauth.pb.go | 193 ++++++++ .../authlistener_grpc_adaptor_wrapper.go | 76 +++ .../authlistener_netlistener_adaptor.go | 102 ++++ rpc/rpc_client.go | 106 ++++ rpc/rpc_debug.go | 20 + rpc/rpc_doc.go | 118 +++++ rpc/rpc_logging.go | 34 ++ rpc/rpc_mux.go | 57 +++ rpc/rpc_server.go | 119 +++++ rpc/transportmux/transportmux.go | 205 ++++++++ rpc/versionhandshake/versionhandshake.go | 181 +++++++ .../versionhandshake/versionhandshake_test.go | 2 +- .../versionhandshake_transport_wrappers.go | 66 +++ tlsconf/tlsconf.go | 35 +- transport/fromconfig/transport_fromconfig.go 
| 58 +++ .../local}/connect_local.go | 9 +- .../serve => transport/local}/serve_local.go | 56 +-- .../ssh}/connect_ssh.go | 19 +- .../ssh}/serve_stdinserver.go | 97 ++-- .../tcp}/connect_tcp.go | 14 +- .../serve => transport/tcp}/serve_tcp.go | 37 +- .../tls}/connect_tls.go | 14 +- transport/tls/serve_tls.go | 89 ++++ transport/tls/tls_wire_adaptor.go | 47 ++ transport/transport.go | 84 ++++ util/bytecounter/bytecounter_streamcopier.go | 71 +++ .../bytecounter_streamcopier_test.go | 38 ++ zfs/diff.go | 2 +- zfs/mapping.go | 8 +- zfs/zfs.go | 324 ++++++++++++- zfs/zfs_debug.go | 20 + zfs/zfs_pipe.go | 18 + zfs/zfs_pipe_linux.go | 21 + zfs/zfs_test.go | 2 +- 100 files changed, 6460 insertions(+), 1485 deletions(-) delete mode 100644 config/config_rpc_test.go delete mode 100644 daemon/logging/adaptors.go delete mode 100644 daemon/streamrpcconfig/streamrpcconfig.go delete mode 100644 daemon/transport/connecter/connecter.go delete mode 100644 daemon/transport/handshake.go delete mode 100644 daemon/transport/serve/serve.go delete mode 100644 daemon/transport/serve/serve_tls.go create mode 100644 rpc/dataconn/base2bufpool/base2bufpool.go create mode 100644 rpc/dataconn/base2bufpool/base2bufpool_test.go create mode 100644 rpc/dataconn/base2bufpool/nofitbehavior_enumer.go create mode 100644 rpc/dataconn/dataconn_client.go create mode 100644 rpc/dataconn/dataconn_debug.go create mode 100644 rpc/dataconn/dataconn_server.go create mode 100644 rpc/dataconn/dataconn_shared.go create mode 100644 rpc/dataconn/dataconn_test.go create mode 100644 rpc/dataconn/frameconn/frameconn.go create mode 100644 rpc/dataconn/frameconn/frameconn_prometheus.go create mode 100644 rpc/dataconn/frameconn/frameconn_shutdown_fsm.go create mode 100644 rpc/dataconn/frameconn/frameconn_test.go create mode 100644 rpc/dataconn/heartbeatconn/heartbeatconn.go create mode 100644 rpc/dataconn/heartbeatconn/heartbeatconn_debug.go create mode 100644 rpc/dataconn/heartbeatconn/heartbeatconn_test.go create mode 
100644 rpc/dataconn/microbenchmark/microbenchmark.go create mode 100644 rpc/dataconn/stream/stream.go create mode 100644 rpc/dataconn/stream/stream_conn.go create mode 100644 rpc/dataconn/stream/stream_debug.go create mode 100644 rpc/dataconn/stream/stream_test.go create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml create mode 100755 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go create mode 100644 rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go create mode 100644 rpc/dataconn/timeoutconn/timeoutconn.go create mode 100644 rpc/dataconn/timeoutconn/timeoutconn_debug.go create mode 100644 rpc/dataconn/timeoutconn/timeoutconn_test.go create mode 100644 rpc/grpcclientidentity/authlistener_grpc_adaptor.go create mode 100644 rpc/grpcclientidentity/example/grpcauth.proto create mode 100644 rpc/grpcclientidentity/example/main.go create mode 100644 rpc/grpcclientidentity/example/pdu/grpcauth.pb.go create mode 100644 
rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go create mode 100644 rpc/netadaptor/authlistener_netlistener_adaptor.go create mode 100644 rpc/rpc_client.go create mode 100644 rpc/rpc_debug.go create mode 100644 rpc/rpc_doc.go create mode 100644 rpc/rpc_logging.go create mode 100644 rpc/rpc_mux.go create mode 100644 rpc/rpc_server.go create mode 100644 rpc/transportmux/transportmux.go create mode 100644 rpc/versionhandshake/versionhandshake.go rename daemon/transport/handshake_test.go => rpc/versionhandshake/versionhandshake_test.go (99%) create mode 100644 rpc/versionhandshake/versionhandshake_transport_wrappers.go create mode 100644 transport/fromconfig/transport_fromconfig.go rename {daemon/transport/connecter => transport/local}/connect_local.go (72%) rename {daemon/transport/serve => transport/local}/serve_local.go (74%) rename {daemon/transport/connecter => transport/ssh}/connect_ssh.go (72%) rename {daemon/transport/serve => transport/ssh}/serve_stdinserver.go (57%) rename {daemon/transport/connecter => transport/tcp}/connect_tcp.go (54%) rename {daemon/transport/serve => transport/tcp}/serve_tcp.go (68%) rename {daemon/transport/connecter => transport/tls}/connect_tls.go (73%) create mode 100644 transport/tls/serve_tls.go create mode 100644 transport/tls/tls_wire_adaptor.go create mode 100644 transport/transport.go create mode 100644 util/bytecounter/bytecounter_streamcopier.go create mode 100644 util/bytecounter/bytecounter_streamcopier_test.go create mode 100644 zfs/zfs_debug.go create mode 100644 zfs/zfs_pipe.go create mode 100644 zfs/zfs_pipe_linux.go diff --git a/.travis.yml b/.travis.yml index 2321376..8ced506 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,7 +18,7 @@ matrix: # all go entries vary only by go version - language: go go: - - "1.10" + - "1.11" go_import_path: github.com/zrepl/zrepl before_install: - wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip @@ -36,7 +36,7 
@@ matrix: - language: go go: - - "1.11" + - "master" go_import_path: github.com/zrepl/zrepl before_install: - wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip diff --git a/Gopkg.lock b/Gopkg.lock index a11b64a..c670de5 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -80,6 +80,10 @@ "protoc-gen-go/generator/internal/remap", "protoc-gen-go/grpc", "protoc-gen-go/plugin", + "ptypes", + "ptypes/any", + "ptypes/duration", + "ptypes/timestamp", ] pruneopts = "" revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5" @@ -183,30 +187,11 @@ [[projects]] branch = "master" - digest = "1:25559b520313b941b1395cd5d5ee66086b27dc15a1391c0f2aad29d5c2321f4b" + digest = "1:fa72f780ae3b4820ed12cef7a034291ab10d83e2da4ab5ba81afa44d5bf3a529" name = "github.com/problame/go-netssh" packages = ["."] pruneopts = "" - revision = "c56ad38d2c91397ad3c8dd9443d7448e328a9e9e" - -[[projects]] - branch = "master" - digest = "1:8c63c44f018bd52b03ebad65c9df26aabbc6793138e421df1c8c84c285a45bc6" - name = "github.com/problame/go-rwccmd" - packages = ["."] - pruneopts = "" - revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7" - -[[projects]] - digest = "1:1bcbb0a7ad8d3392d446eb583ae5415ff987838a8f7331a36877789be20667e6" - name = "github.com/problame/go-streamrpc" - packages = [ - ".", - "internal/pdu", - ] - pruneopts = "" - revision = "d5d111e014342fe1c37f0b71cc37ec5f2afdfd13" - version = "v0.5" + revision = "09d6bc45d284784cb3e5aaa1998513f37eb19cc6" [[projects]] branch = "master" @@ -279,6 +264,14 @@ revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" version = "v1.1.4" +[[projects]] + digest = "1:f80053a92d9ac874ad97d665d3005c1118ed66e9e88401361dc32defe6bef21c" + name = "github.com/theckman/goconstraint" + packages = ["go1.11/gte"] + pruneopts = "" + revision = "93babf24513d0e8277635da8169fcc5a46ae3f6a" + version = "v1.11.0" + [[projects]] branch = "v2" digest = "1:6b8a6afafde7ed31cd0c577ba40d88ce39e8f1c5eb76d7836be7d5b74f1c534a" @@ -289,21 +282,48 
@@ [[projects]] branch = "master" - digest = "1:9c286cf11d0ca56368185bada5dd6d97b6be4648fc26c354fcba8df7293718f7" + digest = "1:ea539c13b066dac72a940b62f37600a20ab8e88057397c78f3197c1a48475425" + name = "golang.org/x/net" + packages = [ + "context", + "http/httpguts", + "http2", + "http2/hpack", + "idna", + "internal/timeseries", + "trace", + ] + pruneopts = "" + revision = "351d144fa1fc0bd934e2408202be0c29f25e35a0" + +[[projects]] + branch = "master" + digest = "1:f358024b019f87eecaadcb098113a40852c94fe58ea670ef3c3e2d2c7bd93db1" name = "golang.org/x/sys" packages = ["unix"] pruneopts = "" - revision = "bf42f188b9bc6f2cf5b8ee5a912ef1aedd0eba4c" + revision = "4ed8d59d0b35e1e29334a206d1b3f38b1e5dfb31" [[projects]] digest = "1:5acd3512b047305d49e8763eef7ba423901e85d5dd2fd1e71778a0ea8de10bd4" name = "golang.org/x/text" packages = [ + "collate", + "collate/build", "encoding", "encoding/internal/identifier", + "internal/colltab", "internal/gen", + "internal/tag", + "internal/triegen", + "internal/ucd", + "language", + "secure/bidirule", "transform", + "unicode/bidi", "unicode/cldr", + "unicode/norm", + "unicode/rangetable", ] pruneopts = "" revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" @@ -317,6 +337,54 @@ pruneopts = "" revision = "d0ca3933b724e6be513276cc2edb34e10d667438" +[[projects]] + branch = "master" + digest = "1:5fc6c317675b746d0c641b29aa0aab5fcb403c0d07afdbf0de86b0d447a0502a" + name = "google.golang.org/genproto" + packages = ["googleapis/rpc/status"] + pruneopts = "" + revision = "bd91e49a0898e27abb88c339b432fa53d7497ac0" + +[[projects]] + digest = "1:d141efe4aaad714e3059c340901aab3147b6253e58c85dafbcca3dd8b0e88ad6" + name = "google.golang.org/grpc" + packages = [ + ".", + "balancer", + "balancer/base", + "balancer/roundrobin", + "binarylog/grpc_binarylog_v1", + "codes", + "connectivity", + "credentials", + "credentials/internal", + "encoding", + "encoding/proto", + "grpclog", + "internal", + "internal/backoff", + "internal/binarylog", + 
"internal/channelz", + "internal/envconfig", + "internal/grpcrand", + "internal/grpcsync", + "internal/syscall", + "internal/transport", + "keepalive", + "metadata", + "naming", + "peer", + "resolver", + "resolver/dns", + "resolver/passthrough", + "stats", + "status", + "tap", + ] + pruneopts = "" + revision = "df014850f6dee74ba2fc94874043a9f3f75fbfd8" + version = "v1.17.0" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 @@ -332,8 +400,6 @@ "github.com/mattn/go-isatty", "github.com/pkg/errors", "github.com/problame/go-netssh", - "github.com/problame/go-rwccmd", - "github.com/problame/go-streamrpc", "github.com/prometheus/client_golang/prometheus", "github.com/prometheus/client_golang/prometheus/promhttp", "github.com/spf13/cobra", @@ -341,8 +407,13 @@ "github.com/stretchr/testify/assert", "github.com/stretchr/testify/require", "github.com/zrepl/yaml-config", + "golang.org/x/net/context", "golang.org/x/sys/unix", "golang.org/x/tools/cmd/stringer", + "google.golang.org/grpc", + "google.golang.org/grpc/credentials", + "google.golang.org/grpc/keepalive", + "google.golang.org/grpc/peer", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index c213f42..01e2aae 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,14 +1,12 @@ -ignored = [ "github.com/inconshreveable/mousetrap" ] +ignored = [ + "github.com/inconshreveable/mousetrap", +] required = [ "golang.org/x/tools/cmd/stringer", "github.com/alvaroloes/enumer", ] -[[constraint]] - branch = "master" - name = "github.com/ftrvxmtrx/fd" - [[constraint]] branch = "master" name = "github.com/jinzhu/copier" @@ -17,14 +15,6 @@ required = [ branch = "master" name = "github.com/kr/pretty" -[[constraint]] - branch = "master" - name = "github.com/mitchellh/go-homedir" - -[[constraint]] - branch = "master" - name = "github.com/mitchellh/mapstructure" - [[constraint]] name = "github.com/pkg/errors" version = "0.8.0" @@ -33,10 +23,6 @@ required = [ branch = "master" name = "github.com/spf13/cobra" 
-[[constraint]] - name = "github.com/spf13/viper" - version = "1.0.0" - [[constraint]] name = "github.com/stretchr/testify" version = "1.1.4" @@ -49,10 +35,6 @@ required = [ name = "github.com/go-logfmt/logfmt" version = "*" -[[constraint]] - name = "github.com/problame/go-rwccmd" - branch = "master" - [[constraint]] name = "github.com/problame/go-netssh" branch = "master" @@ -63,18 +45,17 @@ required = [ [[constraint]] name = "github.com/golang/protobuf" - version = "1.2.0" + version = "1" [[constraint]] name = "github.com/fatih/color" version = "1.7.0" -[[constraint]] - name = "github.com/problame/go-streamrpc" - version = "0.5.0" - - [[constraint]] name = "github.com/gdamore/tcell" version = "1.0.0" + +[[constraint]] + name = "google.golang.org/grpc" + version = "1" diff --git a/Makefile b/Makefile index 343af9b..120eb28 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ vendordeps: dep ensure -v -vendor-only generate: #not part of the build, must do that manually - protoc -I=replication/pdu --go_out=replication/pdu replication/pdu/pdu.proto + protoc -I=replication/pdu --go_out=plugins=grpc:replication/pdu replication/pdu/pdu.proto go generate -x ./... build: @@ -38,7 +38,10 @@ test: go test ./... vet: - go vet ./... + # for each supported platform to cover conditional compilation + GOOS=linux go vet ./... + GOOS=darwin go vet ./... + GOOS=freebsd go vet ./... 
$(ARTIFACTDIR): mkdir -p "$@" diff --git a/config/config.go b/config/config.go index effa8f8..3c2131a 100644 --- a/config/config.go +++ b/config/config.go @@ -130,7 +130,6 @@ type Global struct { Monitoring []MonitoringEnum `yaml:"monitoring,optional"` Control *GlobalControl `yaml:"control,optional,fromdefaults"` Serve *GlobalServe `yaml:"serve,optional,fromdefaults"` - RPC *RPCConfig `yaml:"rpc,optional,fromdefaults"` } func Default(i interface{}) { @@ -145,23 +144,12 @@ func Default(i interface{}) { } } -type RPCConfig struct { - Timeout time.Duration `yaml:"timeout,optional,positive,default=10s"` - TxChunkSize uint32 `yaml:"tx_chunk_size,optional,default=32768"` - RxStructuredMaxLen uint32 `yaml:"rx_structured_max,optional,default=16777216"` - RxStreamChunkMaxLen uint32 `yaml:"rx_stream_chunk_max,optional,default=16777216"` - RxHeaderMaxLen uint32 `yaml:"rx_header_max,optional,default=40960"` - SendHeartbeatInterval time.Duration `yaml:"send_heartbeat_interval,optional,positive,default=5s"` - -} - type ConnectEnum struct { Ret interface{} } type ConnectCommon struct { - Type string `yaml:"type"` - RPC *RPCConfig `yaml:"rpc,optional"` + Type string `yaml:"type"` } type TCPConnect struct { @@ -203,8 +191,7 @@ type ServeEnum struct { } type ServeCommon struct { - Type string `yaml:"type"` - RPC *RPCConfig `yaml:"rpc,optional"` + Type string `yaml:"type"` } type TCPServe struct { diff --git a/config/config_rpc_test.go b/config/config_rpc_test.go deleted file mode 100644 index f02311e..0000000 --- a/config/config_rpc_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package config - -import ( - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -func TestRPC(t *testing.T) { - conf := testValidConfig(t, ` -jobs: -- name: pull_servers - type: pull - connect: - type: tcp - address: "server1.foo.bar:8888" - rpc: - timeout: 20s # different form default, should merge - root_fs: "pool2/backup_servers" - interval: 10m - pruning: - keep_sender: - - type: not_replicated - 
keep_receiver: - - type: last_n - count: 100 - -- name: pull_servers2 - type: pull - connect: - type: tcp - address: "server1.foo.bar:8888" - rpc: - tx_chunk_size: 0xabcd # different from default, should merge - root_fs: "pool2/backup_servers" - interval: 10m - pruning: - keep_sender: - - type: not_replicated - keep_receiver: - - type: last_n - count: 100 - -- type: sink - name: "laptop_sink" - root_fs: "pool2/backup_laptops" - serve: - type: tcp - listen: "192.168.122.189:8888" - clients: { - "10.23.42.23":"client1" - } - rpc: - rx_structured_max: 0x2342 - -- type: sink - name: "other_sink" - root_fs: "pool2/backup_laptops" - serve: - type: tcp - listen: "192.168.122.189:8888" - clients: { - "10.23.42.23":"client1" - } - rpc: - send_heartbeat_interval: 10s - -`) - - assert.Equal(t, 20*time.Second, conf.Jobs[0].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.Timeout) - assert.Equal(t, uint32(0xabcd), conf.Jobs[1].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.TxChunkSize) - assert.Equal(t, uint32(0x2342), conf.Jobs[2].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.RxStructuredMaxLen) - assert.Equal(t, 10*time.Second, conf.Jobs[3].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.SendHeartbeatInterval) - defConf := RPCConfig{} - Default(&defConf) - assert.Equal(t, defConf.Timeout, conf.Global.RPC.Timeout) -} - -func TestGlobal_DefaultRPCConfig(t *testing.T) { - assert.NotPanics(t, func() { - var c RPCConfig - Default(&c) - assert.NotNil(t, c) - assert.Equal(t, c.TxChunkSize, uint32(1)<<15) - }) -} diff --git a/config/config_test.go b/config/config_test.go index d8e1e2c..d51f3f7 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,11 +1,14 @@ package config import ( - "github.com/kr/pretty" - "github.com/stretchr/testify/require" + "bytes" "path" "path/filepath" "testing" + "text/template" + + "github.com/kr/pretty" + "github.com/stretchr/testify/require" ) func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) { @@ -35,8 +38,21 @@ func 
TestSampleConfigsAreParsedWithoutErrors(t *testing.T) { } +// template must be a template/text template with a single '{{ . }}' as placeholder for val +func testValidConfigTemplate(t *testing.T, tmpl string, val string) *Config { + tmp, err := template.New("master").Parse(tmpl) + if err != nil { + panic(err) + } + var buf bytes.Buffer + err = tmp.Execute(&buf, val) + if err != nil { + panic(err) + } + return testValidConfig(t, buf.String()) +} -func testValidConfig(t *testing.T, input string) (*Config) { +func testValidConfig(t *testing.T, input string) *Config { t.Helper() conf, err := testConfig(t, input) require.NoError(t, err) @@ -47,4 +63,4 @@ func testValidConfig(t *testing.T, input string) (*Config) { func testConfig(t *testing.T, input string) (*Config, error) { t.Helper() return ParseConfigBytes([]byte(input)) -} \ No newline at end of file +} diff --git a/daemon/job/active.go b/daemon/job/active.go index 766736b..907f4be 100644 --- a/daemon/job/active.go +++ b/daemon/job/active.go @@ -2,9 +2,12 @@ package job import ( "context" + "sync" + "time" + "github.com/pkg/errors" - "github.com/problame/go-streamrpc" "github.com/prometheus/client_golang/prometheus" + "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/job/reset" @@ -12,32 +15,30 @@ import ( "github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/daemon/pruner" "github.com/zrepl/zrepl/daemon/snapper" - "github.com/zrepl/zrepl/daemon/transport/connecter" "github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/transport/fromconfig" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/zfs" - "sync" - "time" ) type ActiveSide struct { - mode activeMode - name string - clientFactory *connecter.ClientFactory + mode activeMode + name string + connecter transport.Connecter prunerFactory *pruner.PrunerFactory - - promRepStateSecs 
*prometheus.HistogramVec // labels: state - promPruneSecs *prometheus.HistogramVec // labels: prune_side - promBytesReplicated *prometheus.CounterVec // labels: filesystem + promRepStateSecs *prometheus.HistogramVec // labels: state + promPruneSecs *prometheus.HistogramVec // labels: prune_side + promBytesReplicated *prometheus.CounterVec // labels: filesystem tasksMtx sync.Mutex tasks activeSideTasks } - //go:generate enumer -type=ActiveSideState type ActiveSideState int @@ -48,12 +49,11 @@ const ( ActiveSideDone // also errors ) - type activeSideTasks struct { state ActiveSideState // valid for state ActiveSideReplicating, ActiveSidePruneSender, ActiveSidePruneReceiver, ActiveSideDone - replication *replication.Replication + replication *replication.Replication replicationCancel context.CancelFunc // valid for state ActiveSidePruneSender, ActiveSidePruneReceiver, ActiveSideDone @@ -77,28 +77,59 @@ func (a *ActiveSide) updateTasks(u func(*activeSideTasks)) activeSideTasks { } type activeMode interface { - SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) + ConnectEndpoints(rpcLoggers rpc.Loggers, connecter transport.Connecter) + DisconnectEndpoints() + SenderReceiver() (replication.Sender, replication.Receiver) Type() Type RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}) + ResetConnectBackoff() } type modePush struct { - fsfilter endpoint.FSFilter - snapper *snapper.PeriodicOrManual + setupMtx sync.Mutex + sender *endpoint.Sender + receiver *rpc.Client + fsfilter endpoint.FSFilter + snapper *snapper.PeriodicOrManual } -func (m *modePush) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) { - sender := endpoint.NewSender(m.fsfilter) - receiver := endpoint.NewRemote(client) - return sender, receiver, nil +func (m *modePush) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.receiver != nil || 
m.sender != nil { + panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints") + } + m.sender = endpoint.NewSender(m.fsfilter) + m.receiver = rpc.NewClient(connecter, loggers) +} + +func (m *modePush) DisconnectEndpoints() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + m.receiver.Close() + m.sender = nil + m.receiver = nil +} + +func (m *modePush) SenderReceiver() (replication.Sender, replication.Receiver) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + return m.sender, m.receiver } func (m *modePush) Type() Type { return TypePush } -func (m *modePush) RunPeriodic(ctx context.Context, wakeUpCommon chan <- struct{}) { +func (m *modePush) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}) { m.snapper.Run(ctx, wakeUpCommon) } +func (m *modePush) ResetConnectBackoff() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.receiver != nil { + m.receiver.ResetConnectBackoff() + } +} func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) { m := &modePush{} @@ -116,14 +147,35 @@ func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) } type modePull struct { + setupMtx sync.Mutex + receiver *endpoint.Receiver + sender *rpc.Client rootFS *zfs.DatasetPath interval time.Duration } -func (m *modePull) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) { - sender := endpoint.NewRemote(client) - receiver, err := endpoint.NewReceiver(m.rootFS) - return sender, receiver, err +func (m *modePull) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.receiver != nil || m.sender != nil { + panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints") + } + m.receiver = endpoint.NewReceiver(m.rootFS, false) + m.sender = rpc.NewClient(connecter, loggers) +} + +func (m *modePull) DisconnectEndpoints() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + m.sender.Close() + m.sender = nil 
+ m.receiver = nil +} + +func (m *modePull) SenderReceiver() (replication.Sender, replication.Receiver) { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + return m.sender, m.receiver } func (*modePull) Type() Type { return TypePull } @@ -148,6 +200,14 @@ func (m *modePull) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{} } } +func (m *modePull) ResetConnectBackoff() { + m.setupMtx.Lock() + defer m.setupMtx.Unlock() + if m.sender != nil { + m.sender.ResetConnectBackoff() + } +} + func modePullFromConfig(g *config.Global, in *config.PullJob) (m *modePull, err error) { m = &modePull{} if in.Interval <= 0 { @@ -175,17 +235,17 @@ func activeSide(g *config.Global, in *config.ActiveJob, mode activeMode) (j *Act Subsystem: "replication", Name: "state_time", Help: "seconds spent during replication", - ConstLabels: prometheus.Labels{"zrepl_job":j.name}, + ConstLabels: prometheus.Labels{"zrepl_job": j.name}, }, []string{"state"}) j.promBytesReplicated = prometheus.NewCounterVec(prometheus.CounterOpts{ Namespace: "zrepl", Subsystem: "replication", Name: "bytes_replicated", Help: "number of bytes replicated from sender to receiver per filesystem", - ConstLabels: prometheus.Labels{"zrepl_job":j.name}, + ConstLabels: prometheus.Labels{"zrepl_job": j.name}, }, []string{"filesystem"}) - j.clientFactory, err = connecter.FromConfig(g, in.Connect) + j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect) if err != nil { return nil, errors.Wrap(err, "cannot build client") } @@ -195,7 +255,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, mode activeMode) (j *Act Subsystem: "pruning", Name: "time", Help: "seconds spent in pruner", - ConstLabels: prometheus.Labels{"zrepl_job":j.name}, + ConstLabels: prometheus.Labels{"zrepl_job": j.name}, }, []string{"prune_side"}) j.prunerFactory, err = pruner.NewPrunerFactory(in.Pruning, j.promPruneSecs) if err != nil { @@ -214,7 +274,7 @@ func (j *ActiveSide) RegisterMetrics(registerer prometheus.Registerer) { func (j 
*ActiveSide) Name() string { return j.name } type ActiveSideStatus struct { - Replication *replication.Report + Replication *replication.Report PruningSender, PruningReceiver *pruner.Report } @@ -256,6 +316,7 @@ outer: break outer case <-wakeup.Wait(ctx): + j.mode.ResetConnectBackoff() case <-periodicDone: } invocationCount++ @@ -268,6 +329,9 @@ func (j *ActiveSide) do(ctx context.Context) { log := GetLogger(ctx) ctx = logging.WithSubsystemLoggers(ctx, log) + loggers := rpc.GetLoggersOrPanic(ctx) // filled by WithSubsystemLoggers + j.mode.ConnectEndpoints(loggers, j.connecter) + defer j.mode.DisconnectEndpoints() // allow cancellation of an invocation (this function) ctx, cancelThisRun := context.WithCancel(ctx) @@ -353,13 +417,7 @@ func (j *ActiveSide) do(ctx context.Context) { } }() - client, err := j.clientFactory.NewClient() - if err != nil { - log.WithError(err).Error("factory cannot instantiate streamrpc client") - } - defer client.Close(ctx) - - sender, receiver, err := j.mode.SenderReceiver(client) + sender, receiver := j.mode.SenderReceiver() { select { diff --git a/daemon/job/passive.go b/daemon/job/passive.go index 99071a8..aae3f5a 100644 --- a/daemon/job/passive.go +++ b/daemon/job/passive.go @@ -2,28 +2,30 @@ package job import ( "context" + "fmt" + "github.com/pkg/errors" - "github.com/problame/go-streamrpc" "github.com/prometheus/client_golang/prometheus" + "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/logging" - "github.com/zrepl/zrepl/daemon/transport/serve" "github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/endpoint" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/transport/fromconfig" "github.com/zrepl/zrepl/zfs" - "path" ) type PassiveSide struct { - mode passiveMode - name string - l serve.ListenerFactory - rpcConf *streamrpc.ConnConfig + mode passiveMode + name string + listen transport.AuthenticatedListenerFactory } type 
passiveMode interface { - ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc + Handler() rpc.Handler RunPeriodic(ctx context.Context) Type() Type } @@ -34,26 +36,8 @@ type modeSink struct { func (m *modeSink) Type() Type { return TypeSink } -func (m *modeSink) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc { - log := GetLogger(ctx) - - clientRootStr := path.Join(m.rootDataset.ToString(), conn.ClientIdentity()) - clientRoot, err := zfs.NewDatasetPath(clientRootStr) - if err != nil { - log.WithError(err). - WithField("client_identity", conn.ClientIdentity()). - Error("cannot build client filesystem map (client identity must be a valid ZFS FS name") - } - log.WithField("client_root", clientRoot).Debug("client root") - - local, err := endpoint.NewReceiver(clientRoot) - if err != nil { - log.WithError(err).Error("unexpected error: cannot convert mapping to filter") - return nil - } - - h := endpoint.NewHandler(local) - return h.Handle +func (m *modeSink) Handler() rpc.Handler { + return endpoint.NewReceiver(m.rootDataset, true) } func (m *modeSink) RunPeriodic(_ context.Context) {} @@ -72,7 +56,7 @@ func modeSinkFromConfig(g *config.Global, in *config.SinkJob) (m *modeSink, err type modeSource struct { fsfilter zfs.DatasetFilter - snapper *snapper.PeriodicOrManual + snapper *snapper.PeriodicOrManual } func modeSourceFromConfig(g *config.Global, in *config.SourceJob) (m *modeSource, err error) { @@ -93,10 +77,8 @@ func modeSourceFromConfig(g *config.Global, in *config.SourceJob) (m *modeSource func (m *modeSource) Type() Type { return TypeSource } -func (m *modeSource) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc { - sender := endpoint.NewSender(m.fsfilter) - h := endpoint.NewHandler(sender) - return h.Handle +func (m *modeSource) Handler() rpc.Handler { + return endpoint.NewSender(m.fsfilter) } func (m *modeSource) RunPeriodic(ctx context.Context) { 
@@ -106,8 +88,8 @@ func (m *modeSource) RunPeriodic(ctx context.Context) { func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passiveMode) (s *PassiveSide, err error) { s = &PassiveSide{mode: mode, name: in.Name} - if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil { - return nil, errors.Wrap(err, "cannot build server") + if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil { + return nil, errors.Wrap(err, "cannot build listener factory") } return s, nil @@ -115,7 +97,7 @@ func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passive func (j *PassiveSide) Name() string { return j.name } -type PassiveStatus struct {} +type PassiveStatus struct{} func (s *PassiveSide) Status() *Status { return &Status{Type: s.mode.Type()} // FIXME PassiveStatus @@ -127,70 +109,30 @@ func (j *PassiveSide) Run(ctx context.Context) { log := GetLogger(ctx) defer log.Info("job exiting") - - l, err := j.l.Listen() - if err != nil { - log.WithError(err).Error("cannot listen") - return - } - defer l.Close() - + ctx = logging.WithSubsystemLoggers(ctx, log) { - ctx, cancel := context.WithCancel(logging.WithSubsystemLoggers(ctx, log)) // shadowing + ctx, cancel := context.WithCancel(ctx) // shadowing defer cancel() go j.mode.RunPeriodic(ctx) } - log.WithField("addr", l.Addr()).Debug("accepting connections") - var connId int -outer: - for { - - select { - case res := <-accept(ctx, l): - if res.err != nil { - log.WithError(res.err).Info("accept error") - continue - } - conn := res.conn - connId++ - connLog := log. - WithField("connID", connId) - connLog. - WithField("addr", conn.RemoteAddr()). - WithField("client_identity", conn.ClientIdentity()). 
- Info("handling connection") - go func() { - defer connLog.Info("finished handling connection") - defer conn.Close() - ctx := logging.WithSubsystemLoggers(ctx, connLog) - handleFunc := j.mode.ConnHandleFunc(ctx, conn) - if handleFunc == nil { - return - } - if err := streamrpc.ServeConn(ctx, conn, j.rpcConf, handleFunc); err != nil { - log.WithError(err).Error("error serving client") - } - }() - - case <-ctx.Done(): - break outer - } - + handler := j.mode.Handler() + if handler == nil { + panic(fmt.Sprintf("implementation error: j.mode.Handler() returned nil: %#v", j)) } -} + ctxInterceptor := func(handlerCtx context.Context) context.Context { + return logging.WithSubsystemLoggers(handlerCtx, log) + } -type acceptResult struct { - conn serve.AuthenticatedConn - err error -} + rpcLoggers := rpc.GetLoggersOrPanic(ctx) // WithSubsystemLoggers above + server := rpc.NewServer(handler, rpcLoggers, ctxInterceptor) -func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult { - c := make(chan acceptResult, 1) - go func() { - conn, err := listener.Accept(ctx) - c <- acceptResult{conn, err} - }() - return c + listener, err := j.listen() + if err != nil { + log.WithError(err).Error("cannot listen") + return + } + + server.Serve(ctx, listener) } diff --git a/daemon/logging/adaptors.go b/daemon/logging/adaptors.go deleted file mode 100644 index c5a7196..0000000 --- a/daemon/logging/adaptors.go +++ /dev/null @@ -1,32 +0,0 @@ -package logging - -import ( - "fmt" - "github.com/problame/go-streamrpc" - "github.com/zrepl/zrepl/logger" - "strings" -) - -type streamrpcLogAdaptor = twoClassLogAdaptor - -type twoClassLogAdaptor struct { - logger.Logger -} - -var _ streamrpc.Logger = twoClassLogAdaptor{} - -func (a twoClassLogAdaptor) Errorf(fmtStr string, args ...interface{}) { - const errorSuffix = ": %s" - if len(args) == 1 { - if err, ok := args[0].(error); ok && strings.HasSuffix(fmtStr, errorSuffix) { - msg := strings.TrimSuffix(fmtStr, errorSuffix) 
- a.WithError(err).Error(msg) - return - } - } - a.Logger.Error(fmt.Sprintf(fmtStr, args...)) -} - -func (a twoClassLogAdaptor) Infof(fmtStr string, args ...interface{}) { - a.Logger.Debug(fmt.Sprintf(fmtStr, args...)) -} diff --git a/daemon/logging/build_logging.go b/daemon/logging/build_logging.go index fcc1fa4..bcefec5 100644 --- a/daemon/logging/build_logging.go +++ b/daemon/logging/build_logging.go @@ -4,18 +4,21 @@ import ( "context" "crypto/tls" "crypto/x509" + "os" + "github.com/mattn/go-isatty" "github.com/pkg/errors" - "github.com/problame/go-streamrpc" + "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/pruner" + "github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/rpc" + "github.com/zrepl/zrepl/rpc/transportmux" "github.com/zrepl/zrepl/tlsconf" - "os" - "github.com/zrepl/zrepl/daemon/snapper" - "github.com/zrepl/zrepl/daemon/transport/serve" + "github.com/zrepl/zrepl/transport" ) func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) { @@ -60,22 +63,41 @@ func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) } +type Subsystem string + const ( - SubsysReplication = "repl" - SubsysStreamrpc = "rpc" - SubsyEndpoint = "endpoint" + SubsysReplication Subsystem = "repl" + SubsyEndpoint Subsystem = "endpoint" + SubsysPruning Subsystem = "pruning" + SubsysSnapshot Subsystem = "snapshot" + SubsysTransport Subsystem = "transport" + SubsysTransportMux Subsystem = "transportmux" + SubsysRPC Subsystem = "rpc" + SubsysRPCControl Subsystem = "rpc.ctrl" + SubsysRPCData Subsystem = "rpc.data" ) func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Context { - ctx = replication.WithLogger(ctx, log.WithField(SubsysField, "repl")) - ctx = streamrpc.ContextWithLogger(ctx, streamrpcLogAdaptor{log.WithField(SubsysField, "rpc")}) - ctx = endpoint.WithLogger(ctx, 
log.WithField(SubsysField, "endpoint")) - ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning")) - ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot")) - ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve")) + ctx = replication.WithLogger(ctx, log.WithField(SubsysField, SubsysReplication)) + ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, SubsyEndpoint)) + ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, SubsysPruning)) + ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, SubsysSnapshot)) + ctx = transport.WithLogger(ctx, log.WithField(SubsysField, SubsysTransport)) + ctx = transportmux.WithLogger(ctx, log.WithField(SubsysField, SubsysTransportMux)) + ctx = rpc.WithLoggers(ctx, + rpc.Loggers{ + General: log.WithField(SubsysField, SubsysRPC), + Control: log.WithField(SubsysField, SubsysRPCControl), + Data: log.WithField(SubsysField, SubsysRPCData), + }, + ) return ctx } +func LogSubsystem(log logger.Logger, subsys Subsystem) logger.Logger { + return log.ReplaceField(SubsysField, subsys) +} + func parseLogFormat(i interface{}) (f EntryFormatter, err error) { var is string switch j := i.(type) { diff --git a/daemon/prometheus.go b/daemon/prometheus.go index 7607b94..1f10e9d 100644 --- a/daemon/prometheus.go +++ b/daemon/prometheus.go @@ -7,6 +7,7 @@ import ( "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/job" "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" "github.com/zrepl/zrepl/zfs" "net" "net/http" @@ -49,6 +50,10 @@ func (j *prometheusJob) Run(ctx context.Context) { panic(err) } + if err := frameconn.PrometheusRegister(prometheus.DefaultRegisterer); err != nil { + panic(err) + } + log := job.GetLogger(ctx) l, err := net.Listen("tcp", j.listen) diff --git a/daemon/pruner/pruner.go b/daemon/pruner/pruner.go index d5705db..bb515ad 100644 --- a/daemon/pruner/pruner.go +++ b/daemon/pruner/pruner.go @@ -11,7 +11,6 @@ import ( 
"github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/watchdog" - "github.com/problame/go-streamrpc" "net" "sort" "strings" @@ -19,14 +18,15 @@ import ( "time" ) -// Try to keep it compatible with gitub.com/zrepl/zrepl/replication.Endpoint +// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint type History interface { ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) } +// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint type Target interface { - ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) - ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS + ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) + ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) } @@ -346,7 +346,6 @@ type Error interface { } var _ Error = net.Error(nil) -var _ Error = streamrpc.Error(nil) func shouldRetry(e error) bool { if neterr, ok := e.(net.Error); ok { @@ -381,10 +380,11 @@ func statePlan(a *args, u updater) state { ka = &pruner.Progress }) - tfss, err := target.ListFilesystems(ctx) + tfssres, err := target.ListFilesystems(ctx, &pdu.ListFilesystemReq{}) if err != nil { return onErr(u, err) } + tfss := tfssres.GetFilesystems() pfss := make([]*fs, len(tfss)) for i, tfs := range tfss { @@ -398,11 +398,12 @@ func statePlan(a *args, u updater) state { } pfss[i] = pfs - tfsvs, err := target.ListFilesystemVersions(ctx, tfs.Path) + tfsvsres, err := target.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: tfs.Path}) if err != nil { l.WithError(err).Error("cannot list filesystem versions") return onErr(u, err) } + tfsvs := tfsvsres.GetVersions() // no 
progress here since we could run in a live-lock (must have used target AND receiver before progress) pfs.snaps = make([]pruning.Snapshot, 0, len(tfsvs)) diff --git a/daemon/pruner/pruner_test.go b/daemon/pruner/pruner_test.go index 47d8f41..23a10e8 100644 --- a/daemon/pruner/pruner_test.go +++ b/daemon/pruner/pruner_test.go @@ -44,7 +44,7 @@ type mockTarget struct { destroyErrs map[string][]error } -func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { +func (t *mockTarget) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { if len(t.listFilesystemsErr) > 0 { e := t.listFilesystemsErr[0] t.listFilesystemsErr = t.listFilesystemsErr[1:] @@ -54,10 +54,11 @@ func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, er for i := range fss { fss[i] = t.fss[i].Filesystem() } - return fss, nil + return &pdu.ListFilesystemRes{Filesystems: fss}, nil } -func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { +func (t *mockTarget) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + fs := req.Filesystem if len(t.listVersionsErrs[fs]) != 0 { e := t.listVersionsErrs[fs][0] t.listVersionsErrs[fs] = t.listVersionsErrs[fs][1:] @@ -68,7 +69,7 @@ func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]* if mfs.path != fs { continue } - return mfs.FilesystemVersions(), nil + return &pdu.ListFilesystemVersionsRes{Versions: mfs.FilesystemVersions()}, nil } return nil, fmt.Errorf("filesystem %s does not exist", fs) } diff --git a/daemon/snapper/snapper.go b/daemon/snapper/snapper.go index 6cd5b98..c6cc9a8 100644 --- a/daemon/snapper/snapper.go +++ b/daemon/snapper/snapper.go @@ -177,7 +177,7 @@ func onMainCtxDone(ctx context.Context, u updater) state { } func syncUp(a args, u updater) state { - fss, err := listFSes(a.fsf) + fss, err := 
listFSes(a.ctx, a.fsf) if err != nil { return onErr(err, u) } @@ -204,7 +204,7 @@ func plan(a args, u updater) state { u(func(snapper *Snapper) { snapper.lastInvocation = time.Now() }) - fss, err := listFSes(a.fsf) + fss, err := listFSes(a.ctx, a.fsf) if err != nil { return onErr(err, u) } @@ -299,8 +299,8 @@ func wait(a args, u updater) state { } } -func listFSes(mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) { - return zfs.ZFSListMapping(mf) +func listFSes(ctx context.Context, mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) { + return zfs.ZFSListMapping(ctx, mf) } func findSyncPoint(log Logger, fss []*zfs.DatasetPath, prefix string, interval time.Duration) (syncPoint time.Time, err error) { diff --git a/daemon/streamrpcconfig/streamrpcconfig.go b/daemon/streamrpcconfig/streamrpcconfig.go deleted file mode 100644 index da28d5d..0000000 --- a/daemon/streamrpcconfig/streamrpcconfig.go +++ /dev/null @@ -1,25 +0,0 @@ -package streamrpcconfig - -import ( - "github.com/problame/go-streamrpc" - "github.com/zrepl/zrepl/config" -) - -func FromDaemonConfig(g *config.Global, in *config.RPCConfig) (*streamrpc.ConnConfig, error) { - conf := in - if conf == nil { - conf = g.RPC - } - srpcConf := &streamrpc.ConnConfig{ - RxHeaderMaxLen: conf.RxHeaderMaxLen, - RxStructuredMaxLen: conf.RxStructuredMaxLen, - RxStreamMaxChunkSize: conf.RxStreamChunkMaxLen, - TxChunkSize: conf.TxChunkSize, - Timeout: conf.Timeout, - SendHeartbeatInterval: conf.SendHeartbeatInterval, - } - if err := srpcConf.Validate(); err != nil { - return nil, err - } - return srpcConf, nil -} diff --git a/daemon/transport/connecter/connecter.go b/daemon/transport/connecter/connecter.go deleted file mode 100644 index fa772a7..0000000 --- a/daemon/transport/connecter/connecter.go +++ /dev/null @@ -1,84 +0,0 @@ -package connecter - -import ( - "context" - "fmt" - "github.com/problame/go-streamrpc" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/streamrpcconfig" 
- "github.com/zrepl/zrepl/daemon/transport" - "net" - "time" -) - - -type HandshakeConnecter struct { - connecter streamrpc.Connecter -} - -func (c HandshakeConnecter) Connect(ctx context.Context) (net.Conn, error) { - conn, err := c.connecter.Connect(ctx) - if err != nil { - return nil, err - } - dl, ok := ctx.Deadline() - if !ok { - dl = time.Now().Add(10 * time.Second) // FIXME constant - } - if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil { - conn.Close() - return nil, err - } - return conn, nil -} - - - -func FromConfig(g *config.Global, in config.ConnectEnum) (*ClientFactory, error) { - var ( - connecter streamrpc.Connecter - errConnecter, errRPC error - connConf *streamrpc.ConnConfig - ) - switch v := in.Ret.(type) { - case *config.SSHStdinserverConnect: - connecter, errConnecter = SSHStdinserverConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.TCPConnect: - connecter, errConnecter = TCPConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.TLSConnect: - connecter, errConnecter = TLSConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.LocalConnect: - connecter, errConnecter = LocalConnecterFromConfig(v) - connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC) - default: - panic(fmt.Sprintf("implementation error: unknown connecter type %T", v)) - } - - if errConnecter != nil { - return nil, errConnecter - } - if errRPC != nil { - return nil, errRPC - } - - config := streamrpc.ClientConfig{ConnConfig: connConf} - if err := config.Validate(); err != nil { - return nil, err - } - - connecter = HandshakeConnecter{connecter} - - return &ClientFactory{connecter: connecter, config: &config}, nil -} - -type ClientFactory struct { - connecter streamrpc.Connecter - config *streamrpc.ClientConfig -} - -func (f ClientFactory) NewClient() (*streamrpc.Client, error) { - return 
streamrpc.NewClient(f.connecter, f.config) -} diff --git a/daemon/transport/handshake.go b/daemon/transport/handshake.go deleted file mode 100644 index ecfd495..0000000 --- a/daemon/transport/handshake.go +++ /dev/null @@ -1,136 +0,0 @@ -package transport - -import ( - "bytes" - "fmt" - "io" - "net" - "strings" - "time" - "unicode/utf8" -) - -type HandshakeMessage struct { - ProtocolVersion int - Extensions []string -} - -func (m *HandshakeMessage) Encode() ([]byte, error) { - if m.ProtocolVersion <= 0 || m.ProtocolVersion > 9999 { - return nil, fmt.Errorf("protocol version must be in [1, 9999]") - } - if len(m.Extensions) >= 9999 { - return nil, fmt.Errorf("protocol only supports [0, 9999] extensions") - } - // EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions - var extensions strings.Builder - for i, ext := range m.Extensions { - if strings.ContainsAny(ext, "\n") { - return nil, fmt.Errorf("Extension #%d contains forbidden newline character", i) - } - if !utf8.ValidString(ext) { - return nil, fmt.Errorf("Extension #%d is not valid UTF-8", i) - } - extensions.WriteString(ext) - extensions.WriteString("\n") - } - withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s", - m.ProtocolVersion, len(m.Extensions), extensions.String()) - withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen) - return []byte(withLen), nil -} - -func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error { - var lenAndSpace [11]byte - if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil { - return err - } - if !utf8.Valid(lenAndSpace[:]) { - return fmt.Errorf("invalid start of handshake message: not valid UTF-8") - } - var followLen int - n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen) - if n != 1 || err != nil { - return fmt.Errorf("could not parse handshake message length") - } - if followLen > maxLen { - return fmt.Errorf("handshake message length exceeds max length (%d vs 
%d)", - followLen, maxLen) - } - - var buf bytes.Buffer - _, err = io.Copy(&buf, io.LimitReader(r, int64(followLen))) - if err != nil { - return err - } - - var ( - protoVersion, extensionCount int - ) - n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n", - &protoVersion, &extensionCount) - if n != 2 || err != nil { - return fmt.Errorf("could not parse handshake message: %s", err) - } - if protoVersion < 1 { - return fmt.Errorf("invalid protocol version %q", protoVersion) - } - m.ProtocolVersion = protoVersion - - if extensionCount < 0 { - return fmt.Errorf("invalid extension count %q", extensionCount) - } - if extensionCount == 0 { - if buf.Len() != 0 { - return fmt.Errorf("unexpected data trailing after header") - } - m.Extensions = nil - return nil - } - s := buf.String() - if strings.Count(s, "\n") != extensionCount { - return fmt.Errorf("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount) - } - exts := strings.Split(s, "\n") - if exts[len(exts)-1] != "" { - return fmt.Errorf("unexpected data trailing after last extension newline") - } - m.Extensions = exts[0:len(exts)-1] - - return nil -} - -func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error { - // current protocol version is hardcoded here - return DoHandshakeVersion(conn, deadline, 1) -} - -func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error { - ours := HandshakeMessage{ - ProtocolVersion: version, - Extensions: nil, - } - hsb, err := ours.Encode() - if err != nil { - return fmt.Errorf("could not encode protocol banner: %s", err) - } - - conn.SetDeadline(deadline) - _, err = io.Copy(conn, bytes.NewBuffer(hsb)) - if err != nil { - return fmt.Errorf("could not send protocol banner: %s", err) - } - - theirs := HandshakeMessage{} - if err := theirs.DecodeReader(conn, 16 * 4096); err != nil { // FIXME constant - return fmt.Errorf("could not decode protocol banner: %s", err) - } - - if 
theirs.ProtocolVersion != ours.ProtocolVersion { - return fmt.Errorf("protocol versions do not match: ours is %d, theirs is %d", - ours.ProtocolVersion, theirs.ProtocolVersion) - } - // ignore extensions, we don't use them - - return nil -} diff --git a/daemon/transport/serve/serve.go b/daemon/transport/serve/serve.go deleted file mode 100644 index c1b3bb1..0000000 --- a/daemon/transport/serve/serve.go +++ /dev/null @@ -1,147 +0,0 @@ -package serve - -import ( - "github.com/pkg/errors" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/transport" - "net" - "github.com/zrepl/zrepl/daemon/streamrpcconfig" - "github.com/problame/go-streamrpc" - "context" - "github.com/zrepl/zrepl/logger" - "github.com/zrepl/zrepl/zfs" - "time" -) - -type contextKey int - -const contextKeyLog contextKey = 0 - -type Logger = logger.Logger - -func WithLogger(ctx context.Context, log Logger) context.Context { - return context.WithValue(ctx, contextKeyLog, log) -} - -func getLogger(ctx context.Context) Logger { - if log, ok := ctx.Value(contextKeyLog).(Logger); ok { - return log - } - return logger.NewNullLogger() -} - -type AuthenticatedConn interface { - net.Conn - // ClientIdentity must be a string that satisfies ValidateClientIdentity - ClientIdentity() string -} - -// A client identity must be a single component in a ZFS filesystem path -func ValidateClientIdentity(in string) (err error) { - path, err := zfs.NewDatasetPath(in) - if err != nil { - return err - } - if path.Length() != 1 { - return errors.New("client identity must be a single path comonent (not empty, no '/')") - } - return nil -} - -type authConn struct { - net.Conn - clientIdentity string -} - -var _ AuthenticatedConn = authConn{} - -func (c authConn) ClientIdentity() string { - if err := ValidateClientIdentity(c.clientIdentity); err != nil { - panic(err) - } - return c.clientIdentity -} - -// like net.Listener, but with an AuthenticatedConn instead of net.Conn -type AuthenticatedListener interface { - 
Addr() (net.Addr) - Accept(ctx context.Context) (AuthenticatedConn, error) - Close() error -} - -type ListenerFactory interface { - Listen() (AuthenticatedListener, error) -} - -type HandshakeListenerFactory struct { - lf ListenerFactory -} - -func (lf HandshakeListenerFactory) Listen() (AuthenticatedListener, error) { - l, err := lf.lf.Listen() - if err != nil { - return nil, err - } - return HandshakeListener{l}, nil -} - -type HandshakeListener struct { - l AuthenticatedListener -} - -func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() } - -func (l HandshakeListener) Close() error { return l.l.Close() } - -func (l HandshakeListener) Accept(ctx context.Context) (AuthenticatedConn, error) { - conn, err := l.l.Accept(ctx) - if err != nil { - return nil, err - } - dl, ok := ctx.Deadline() - if !ok { - dl = time.Now().Add(10*time.Second) // FIXME constant - } - if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil { - conn.Close() - return nil, err - } - return conn, nil -} - -func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) { - - var ( - lfError, rpcErr error - ) - switch v := in.Ret.(type) { - case *config.TCPServe: - lf, lfError = TCPListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.TLSServe: - lf, lfError = TLSListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.StdinserverServer: - lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - case *config.LocalServe: - lf, lfError = LocalListenerFactoryFromConfig(g, v) - conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) - default: - return nil, nil, errors.Errorf("internal error: unknown serve type %T", v) - } - - if lfError != nil { - return nil, nil, lfError - } - if rpcErr != nil { - return nil, nil, rpcErr - } - - lf = 
HandshakeListenerFactory{lf} - - return lf, conf, nil - -} - - diff --git a/daemon/transport/serve/serve_tls.go b/daemon/transport/serve/serve_tls.go deleted file mode 100644 index bc95e41..0000000 --- a/daemon/transport/serve/serve_tls.go +++ /dev/null @@ -1,83 +0,0 @@ -package serve - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "github.com/pkg/errors" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/tlsconf" - "net" - "time" - "context" -) - -type TLSListenerFactory struct { - address string - clientCA *x509.CertPool - serverCert tls.Certificate - handshakeTimeout time.Duration - clientCNs map[string]struct{} -} - -func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TLSListenerFactory, err error) { - lf = &TLSListenerFactory{ - address: in.Listen, - handshakeTimeout: in.HandshakeTimeout, - } - - if in.Ca == "" || in.Cert == "" || in.Key == "" { - return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") - } - - lf.clientCA, err = tlsconf.ParseCAFile(in.Ca) - if err != nil { - return nil, errors.Wrap(err, "cannot parse ca file") - } - - lf.serverCert, err = tls.LoadX509KeyPair(in.Cert, in.Key) - if err != nil { - return nil, errors.Wrap(err, "cannot parse cer/key pair") - } - - lf.clientCNs = make(map[string]struct{}, len(in.ClientCNs)) - for i, cn := range in.ClientCNs { - if err := ValidateClientIdentity(cn); err != nil { - return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn) - } - // dupes are ok fr now - lf.clientCNs[cn] = struct{}{} - } - - return lf, nil -} - -func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) { - l, err := net.Listen("tcp", f.address) - if err != nil { - return nil, err - } - tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout) - return tlsAuthListener{tl, f.clientCNs}, nil -} - -type tlsAuthListener struct { - *tlsconf.ClientAuthListener - clientCNs map[string]struct{} -} - -func (l tlsAuthListener) Accept(ctx 
context.Context) (AuthenticatedConn, error) { - c, cn, err := l.ClientAuthListener.Accept() - if err != nil { - return nil, err - } - if _, ok := l.clientCNs[cn]; !ok { - if err := c.Close(); err != nil { - getLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name") - } - return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, c.RemoteAddr()) - } - return authConn{c, cn}, nil -} - - diff --git a/docs/configuration/transports.rst b/docs/configuration/transports.rst index a705bc9..0187775 100644 --- a/docs/configuration/transports.rst +++ b/docs/configuration/transports.rst @@ -201,13 +201,17 @@ The serve & connect configuration will thus look like the following: ``ssh+stdinserver`` Transport ----------------------------- -``ssh+stdinserver`` is inspired by `git shell `_ and `Borg Backup `_. -It is provided by the Go package ``github.com/problame/go-netssh``. +``ssh+stdinserver`` uses the ``ssh`` command and some features of the server-side SSH ``authorized_keys`` file. +It is less efficient than other transports because the data passes through two more pipes. +However, it is fairly convenient to set up and allows the zrepl daemon to not be directly exposed to the internet, because all traffic passes through the system's SSH server. -.. ATTENTION:: +The concept is inspired by `git shell `_ and `Borg Backup `_. +The implementation is provided by the Go package ``github.com/problame/go-netssh``. - ``ssh+stdinserver`` has inferior error detection and handling compared to the ``tcp`` and ``tls`` transports. - If you require tested timeout & retry handling, use ``tcp`` or ``tls`` transports, or help improve package go-netssh. +.. NOTE:: + + ``ssh+stdinserver`` generally provides inferior error detection and handling compared to the ``tcp`` and ``tls`` transports. + When encountering such problems, consider using ``tcp`` or ``tls`` transports, or help improve package go-netssh. .. 
_transport-ssh+stdinserver-serve: diff --git a/endpoint/context.go b/endpoint/context.go index 09f9032..20b3296 100644 --- a/endpoint/context.go +++ b/endpoint/context.go @@ -9,6 +9,7 @@ type contextKey int const ( contextKeyLogger contextKey = iota + ClientIdentityKey ) type Logger = logger.Logger diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 4b44825..b90f3f7 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -2,21 +2,19 @@ package endpoint import ( - "bytes" "context" "fmt" - "github.com/golang/protobuf/proto" + "path" + "github.com/pkg/errors" - "github.com/problame/go-streamrpc" "github.com/zrepl/zrepl/replication" "github.com/zrepl/zrepl/replication/pdu" "github.com/zrepl/zrepl/zfs" - "io" ) // Sender implements replication.ReplicationEndpoint for a sending side type Sender struct { - FSFilter zfs.DatasetFilter + FSFilter zfs.DatasetFilter } func NewSender(fsf zfs.DatasetFilter) *Sender { @@ -41,8 +39,8 @@ func (s *Sender) filterCheckFS(fs string) (*zfs.DatasetPath, error) { return dp, nil } -func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { - fss, err := zfs.ZFSListMapping(p.FSFilter) +func (s *Sender) ListFilesystems(ctx context.Context, r *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { + fss, err := zfs.ZFSListMapping(ctx, s.FSFilter) if err != nil { return nil, err } @@ -53,11 +51,12 @@ func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) // FIXME: not supporting ResumeToken yet } } - return rfss, nil + res := &pdu.ListFilesystemRes{Filesystems: rfss, Empty: len(rfss) == 0} + return res, nil } -func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { - lp, err := p.filterCheckFS(fs) +func (s *Sender) ListFilesystemVersions(ctx context.Context, r *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + lp, err := s.filterCheckFS(r.GetFilesystem()) if err != nil { return nil, err } @@ 
-69,32 +68,36 @@ func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu. for i := range fsvs { rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i]) } - return rfsvs, nil + res := &pdu.ListFilesystemVersionsRes{Versions: rfsvs} + return res, nil + } -func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { - _, err := p.filterCheckFS(r.Filesystem) +func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + _, err := s.filterCheckFS(r.Filesystem) if err != nil { return nil, nil, err } - if r.DryRun { - si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "") - if err != nil { - return nil, nil, err - } - var expSize int64 = 0 // protocol says 0 means no estimate - if si.SizeEstimate != -1 { // but si returns -1 for no size estimate - expSize = si.SizeEstimate - } - return &pdu.SendRes{ExpectedSize: expSize}, nil, nil - } else { - stream, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "") - if err != nil { - return nil, nil, err - } - return &pdu.SendRes{}, stream, nil + si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "") + if err != nil { + return nil, nil, err } + var expSize int64 = 0 // protocol says 0 means no estimate + if si.SizeEstimate != -1 { // but si returns -1 for no size estimate + expSize = si.SizeEstimate + } + res := &pdu.SendRes{ExpectedSize: expSize} + + if r.DryRun { + return res, nil, nil + } + + streamCopier, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "") + if err != nil { + return nil, nil, err + } + return res, streamCopier, nil } func (p *Sender) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { @@ -132,6 +135,10 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs } } +func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) { + return nil, fmt.Errorf("sender does not 
implement Receive()") +} + type FSFilter interface { // FIXME unused Filter(path *zfs.DatasetPath) (pass bool, err error) } @@ -146,14 +153,50 @@ type FSMap interface { // FIXME unused // Receiver implements replication.ReplicationEndpoint for a receiving side type Receiver struct { - root *zfs.DatasetPath + rootWithoutClientComponent *zfs.DatasetPath + appendClientIdentity bool } -func NewReceiver(rootDataset *zfs.DatasetPath) (*Receiver, error) { +func NewReceiver(rootDataset *zfs.DatasetPath, appendClientIdentity bool) *Receiver { if rootDataset.Length() <= 0 { - return nil, errors.New("root dataset must not be an empty path") + panic(fmt.Sprintf("root dataset must not be an empty path: %v", rootDataset)) } - return &Receiver{root: rootDataset.Copy()}, nil + return &Receiver{rootWithoutClientComponent: rootDataset.Copy(), appendClientIdentity: appendClientIdentity} +} + +func TestClientIdentity(rootFS *zfs.DatasetPath, clientIdentity string) error { + _, err := clientRoot(rootFS, clientIdentity) + return err +} + +func clientRoot(rootFS *zfs.DatasetPath, clientIdentity string) (*zfs.DatasetPath, error) { + rootFSLen := rootFS.Length() + clientRootStr := path.Join(rootFS.ToString(), clientIdentity) + clientRoot, err := zfs.NewDatasetPath(clientRootStr) + if err != nil { + return nil, err + } + if rootFSLen+1 != clientRoot.Length() { + return nil, fmt.Errorf("client identity must be a single ZFS filesystem path component") + } + return clientRoot, nil +} + +func (s *Receiver) clientRootFromCtx(ctx context.Context) *zfs.DatasetPath { + if !s.appendClientIdentity { + return s.rootWithoutClientComponent.Copy() + } + + clientIdentity, ok := ctx.Value(ClientIdentityKey).(string) + if !ok { + panic(fmt.Sprintf("ClientIdentityKey context value must be set")) + } + + clientRoot, err := clientRoot(s.rootWithoutClientComponent, clientIdentity) + if err != nil { + panic(fmt.Sprintf("ClientIdentityContextKey must have been validated before invoking Receiver: %s", err)) + } + 
return clientRoot } type subroot struct { @@ -180,8 +223,9 @@ func (f subroot) MapToLocal(fs string) (*zfs.DatasetPath, error) { return c, nil } -func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { - filtered, err := zfs.ZFSListMapping(subroot{e.root}) +func (s *Receiver) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { + root := s.clientRootFromCtx(ctx) + filtered, err := zfs.ZFSListMapping(ctx, subroot{root}) if err != nil { return nil, err } @@ -194,19 +238,30 @@ func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, erro WithError(err). WithField("fs", a). Error("inconsistent placeholder property") - return nil, errors.New("server error, see logs") // don't leak path + return nil, errors.New("server error: inconsistent placeholder property") // don't leak path } if ph { + getLogger(ctx). + WithField("fs", a.ToString()). + Debug("ignoring placeholder filesystem") continue } - a.TrimPrefix(e.root) + getLogger(ctx). + WithField("fs", a.ToString()). 
+ Debug("non-placeholder filesystem") + a.TrimPrefix(root) fss = append(fss, &pdu.Filesystem{Path: a.ToString()}) } - return fss, nil + if len(fss) == 0 { + getLogger(ctx).Debug("no non-placeholder filesystems") + return &pdu.ListFilesystemRes{Empty: true}, nil + } + return &pdu.ListFilesystemRes{Filesystems: fss}, nil } -func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { - lp, err := subroot{e.root}.MapToLocal(fs) +func (s *Receiver) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + root := s.clientRootFromCtx(ctx) + lp, err := subroot{root}.MapToLocal(req.GetFilesystem()) if err != nil { return nil, err } @@ -221,18 +276,26 @@ func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pd rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i]) } - return rfsvs, nil + return &pdu.ListFilesystemVersionsRes{Versions: rfsvs}, nil } -func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream io.ReadCloser) error { - defer sendStream.Close() +func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) { + return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver") +} - lp, err := subroot{e.root}.MapToLocal(req.Filesystem) - if err != nil { - return err - } +func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + return nil, nil, fmt.Errorf("receiver does not implement Send()") +} +func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) { getLogger(ctx).Debug("incoming Receive") + defer receive.Close() + + root := s.clientRootFromCtx(ctx) + lp, err := subroot{root}.MapToLocal(req.Filesystem) + if err != nil { + return nil, err + } // create placeholder parent filesystems as appropriate var visitErr error @@ -261,7 +324,7 @@ 
func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream getLogger(ctx).WithField("visitErr", visitErr).Debug("complete tree-walk") if visitErr != nil { - return visitErr + return nil, err } needForceRecv := false @@ -279,19 +342,19 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream getLogger(ctx).Debug("start receive command") - if err := zfs.ZFSRecv(ctx, lp.ToString(), sendStream, args...); err != nil { + if err := zfs.ZFSRecv(ctx, lp.ToString(), receive, args...); err != nil { getLogger(ctx). WithError(err). WithField("args", args). Error("zfs receive failed") - sendStream.Close() - return err + return nil, err } - return nil + return &pdu.ReceiveRes{}, nil } -func (e *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { - lp, err := subroot{e.root}.MapToLocal(req.Filesystem) +func (s *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { + root := s.clientRootFromCtx(ctx) + lp, err := subroot{root}.MapToLocal(req.Filesystem) if err != nil { return nil, err } @@ -326,289 +389,3 @@ func doDestroySnapshots(ctx context.Context, lp *zfs.DatasetPath, snaps []*pdu.F } return res, nil } - -// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= -// RPC STUBS -// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= - -const ( - RPCListFilesystems = "ListFilesystems" - RPCListFilesystemVersions = "ListFilesystemVersions" - RPCReceive = "Receive" - RPCSend = "Send" - RPCSDestroySnapshots = "DestroySnapshots" - RPCReplicationCursor = "ReplicationCursor" -) - -// Remote implements an endpoint stub that uses streamrpc as a transport. 
-type Remote struct { - c *streamrpc.Client -} - -func NewRemote(c *streamrpc.Client) Remote { - return Remote{c} -} - -func (s Remote) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { - req := pdu.ListFilesystemReq{} - b, err := proto.Marshal(&req) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.ListFilesystemRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return res.Filesystems, nil -} - -func (s Remote) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { - req := pdu.ListFilesystemVersionsReq{ - Filesystem: fs, - } - b, err := proto.Marshal(&req) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.ListFilesystemVersionsRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return res.Versions, nil -} - -func (s Remote) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { - b, err := proto.Marshal(r) - if err != nil { - return nil, nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil) - if err != nil { - return nil, nil, err - } - if !r.DryRun && rs == nil { - return nil, nil, errors.New("response does not contain a stream") - } - if r.DryRun && rs != nil { - rs.Close() - return nil, nil, errors.New("response contains unexpected stream (was dry run)") - } - var res pdu.SendRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - rs.Close() - return nil, nil, err - } - return &res, rs, nil -} - -func (s Remote) 
Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error { - defer sendStream.Close() - b, err := proto.Marshal(r) - if err != nil { - return err - } - rb, rs, err := s.c.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream) - getLogger(ctx).WithField("err", err).Debug("Remote.Receive RequestReplyReturned") - if err != nil { - return err - } - if rs != nil { - rs.Close() - return errors.New("response contains unexpected stream") - } - var res pdu.ReceiveRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return err - } - return nil -} - -func (s Remote) DestroySnapshots(ctx context.Context, r *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { - b, err := proto.Marshal(r) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCSDestroySnapshots, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.DestroySnapshotsRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return &res, nil -} - -func (s Remote) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) { - b, err := proto.Marshal(req) - if err != nil { - return nil, err - } - rb, rs, err := s.c.RequestReply(ctx, RPCReplicationCursor, bytes.NewBuffer(b), nil) - if err != nil { - return nil, err - } - if rs != nil { - rs.Close() - return nil, errors.New("response contains unexpected stream") - } - var res pdu.ReplicationCursorRes - if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - return nil, err - } - return &res, nil -} - -// Handler implements the server-side streamrpc.HandlerFunc for a Remote endpoint stub. 
-type Handler struct { - ep replication.Endpoint -} - -func NewHandler(ep replication.Endpoint) Handler { - return Handler{ep} -} - -func (a *Handler) Handle(ctx context.Context, endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) { - - switch endpoint { - case RPCListFilesystems: - var req pdu.ListFilesystemReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - fsses, err := a.ep.ListFilesystems(ctx) - if err != nil { - return nil, nil, err - } - res := &pdu.ListFilesystemRes{ - Filesystems: fsses, - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - case RPCListFilesystemVersions: - - var req pdu.ListFilesystemVersionsReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem) - if err != nil { - return nil, nil, err - } - res := &pdu.ListFilesystemVersionsRes{ - Versions: fsvs, - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - case RPCSend: - - sender, ok := a.ep.(replication.Sender) - if !ok { - goto Err - } - - var req pdu.SendReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - res, sendStream, err := sender.Send(ctx, &req) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), sendStream, err - - case RPCReceive: - - receiver, ok := a.ep.(replication.Receiver) - if !ok { - goto Err - } - - var req pdu.ReceiveReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - err := receiver.Receive(ctx, &req, reqStream) - if err != nil { - return nil, nil, err - } - b, err := 
proto.Marshal(&pdu.ReceiveRes{}) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, err - - case RPCSDestroySnapshots: - - var req pdu.DestroySnapshotsReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - - res, err := a.ep.DestroySnapshots(ctx, &req) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - case RPCReplicationCursor: - - sender, ok := a.ep.(replication.Sender) - if !ok { - goto Err - } - - var req pdu.ReplicationCursorReq - if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { - return nil, nil, err - } - res, err := sender.ReplicationCursor(ctx, &req) - if err != nil { - return nil, nil, err - } - b, err := proto.Marshal(res) - if err != nil { - return nil, nil, err - } - return bytes.NewBuffer(b), nil, nil - - } -Err: - return nil, nil, errors.New("no handler for given endpoint") -} diff --git a/replication/fsrep/fsfsm.go b/replication/fsrep/fsfsm.go index 6265d01..cfc8a5e 100644 --- a/replication/fsrep/fsfsm.go +++ b/replication/fsrep/fsfsm.go @@ -6,16 +6,16 @@ import ( "context" "errors" "fmt" - "github.com/prometheus/client_golang/prometheus" - "github.com/zrepl/zrepl/util/watchdog" - "io" "net" "sync" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/replication/pdu" - "github.com/zrepl/zrepl/util" + "github.com/zrepl/zrepl/util/bytecounter" + "github.com/zrepl/zrepl/util/watchdog" + "github.com/zrepl/zrepl/zfs" ) type contextKey int @@ -43,7 +43,7 @@ type Sender interface { // If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before // any next call to the parent github.com/zrepl/zrepl/replication.Endpoint. 
// If the send request is for dry run the io.ReadCloser will be nil - Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) + Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) } @@ -51,9 +51,7 @@ type Sender interface { type Receiver interface { // Receive sends r and sendStream (the latter containing a ZFS send stream) // to the parent github.com/zrepl/zrepl/replication.Endpoint. - // Implementors must guarantee that Close was called on sendStream before - // the call to Receive returns. - Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error + Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) } type StepReport struct { @@ -227,7 +225,7 @@ type ReplicationStep struct { // both retry and permanent error err error - byteCounter *util.ByteCounterReader + byteCounter bytecounter.StreamCopier expectedSize int64 // 0 means no size estimate present / possible } @@ -401,37 +399,54 @@ func (s *ReplicationStep) doReplication(ctx context.Context, ka *watchdog.KeepAl sr := s.buildSendRequest(false) log.Debug("initiate send request") - sres, sstream, err := sender.Send(ctx, sr) + sres, sstreamCopier, err := sender.Send(ctx, sr) if err != nil { log.WithError(err).Error("send request failed") return err } - if sstream == nil { + if sstreamCopier == nil { err := errors.New("send request did not return a stream, broken endpoint implementation") return err } + defer sstreamCopier.Close() - s.byteCounter = util.NewByteCounterReader(sstream) - s.byteCounter.SetCallback(1*time.Second, func(i int64) { - ka.MadeProgress() - }) - defer func() { - s.parent.promBytesReplicated.Add(float64(s.byteCounter.Bytes())) + // Install a byte counter to track progress + for status report + s.byteCounter = bytecounter.NewStreamCopier(sstreamCopier) + 
byteCounterStopProgress := make(chan struct{}) + defer close(byteCounterStopProgress) + go func() { + var lastCount int64 + t := time.NewTicker(1 * time.Second) + defer t.Stop() + for { + select { + case <-byteCounterStopProgress: + return + case <-t.C: + newCount := s.byteCounter.Count() + if lastCount != newCount { + ka.MadeProgress() + } else { + lastCount = newCount + } + } + } + }() + defer func() { + s.parent.promBytesReplicated.Add(float64(s.byteCounter.Count())) }() - sstream = s.byteCounter rr := &pdu.ReceiveReq{ Filesystem: fs, ClearResumeToken: !sres.UsedResumeToken, } log.Debug("initiate receive request") - err = receiver.Receive(ctx, rr, sstream) + _, err = receiver.Receive(ctx, rr, s.byteCounter) if err != nil { log. WithError(err). WithField("errType", fmt.Sprintf("%T", err)). Error("receive request failed (might also be error on sender)") - sstream.Close() // This failure could be due to // - an unexpected exit of ZFS on the sending side // - an unexpected exit of ZFS on the receiving side @@ -524,7 +539,7 @@ func (s *ReplicationStep) Report() *StepReport { } bytes := int64(0) if s.byteCounter != nil { - bytes = s.byteCounter.Bytes() + bytes = s.byteCounter.Count() } problem := "" if s.err != nil { diff --git a/replication/mainfsm.go b/replication/mainfsm.go index 2c8de45..5cf1d7b 100644 --- a/replication/mainfsm.go +++ b/replication/mainfsm.go @@ -10,7 +10,6 @@ import ( "github.com/zrepl/zrepl/daemon/job/wakeup" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/watchdog" - "github.com/problame/go-streamrpc" "math/bits" "net" "sort" @@ -106,9 +105,8 @@ func NewReplication(secsPerState *prometheus.HistogramVec, bytesReplicated *prom // named interfaces defined in this package. 
type Endpoint interface { // Does not include placeholder filesystems - ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) - // FIXME document FilteredError handling - ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS + ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) + ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) } @@ -203,7 +201,6 @@ type Error interface { var _ Error = fsrep.Error(nil) var _ Error = net.Error(nil) -var _ Error = streamrpc.Error(nil) func isPermanent(err error) bool { if e, ok := err.(Error); ok { @@ -232,19 +229,20 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r }).rsf() } - sfss, err := sender.ListFilesystems(ctx) + slfssres, err := sender.ListFilesystems(ctx, &pdu.ListFilesystemReq{}) if err != nil { - log.WithError(err).Error("error listing sender filesystems") + log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing sender filesystems") return handlePlanningError(err) } + sfss := slfssres.GetFilesystems() // no progress here since we could run in a live-lock on connectivity issues - rfss, err := receiver.ListFilesystems(ctx) + rlfssres, err := receiver.ListFilesystems(ctx, &pdu.ListFilesystemReq{}) if err != nil { - log.WithError(err).Error("error listing receiver filesystems") + log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing receiver filesystems") return handlePlanningError(err) } - + rfss := rlfssres.GetFilesystems() ka.MadeProgress() // for both sender and receiver q := make([]*fsrep.Replication, 0, len(sfss)) @@ -255,11 +253,12 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r log.Debug("assessing filesystem") - sfsvs, 
err := sender.ListFilesystemVersions(ctx, fs.Path) + sfsvsres, err := sender.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path}) if err != nil { log.WithError(err).Error("cannot get remote filesystem versions") return handlePlanningError(err) } + sfsvs := sfsvsres.GetVersions() ka.MadeProgress() if len(sfsvs) < 1 { @@ -278,7 +277,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r var rfsvs []*pdu.FilesystemVersion if receiverFSExists { - rfsvs, err = receiver.ListFilesystemVersions(ctx, fs.Path) + rfsvsres, err := receiver.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path}) if err != nil { if _, ok := err.(*FilteredError); ok { log.Info("receiver ignores filesystem") @@ -287,6 +286,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r log.WithError(err).Error("receiver error") return handlePlanningError(err) } + rfsvs = rfsvsres.GetVersions() } else { rfsvs = []*pdu.FilesystemVersion{} } diff --git a/replication/pdu/pdu.pb.go b/replication/pdu/pdu.pb.go index f2b7d37..6b7fd86 100644 --- a/replication/pdu/pdu.pb.go +++ b/replication/pdu/pdu.pb.go @@ -7,6 +7,11 @@ import proto "github.com/golang/protobuf/proto" import fmt "fmt" import math "math" +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + // Reference imports to suppress errors if they are not otherwise used. 
var _ = proto.Marshal var _ = fmt.Errorf @@ -38,7 +43,7 @@ func (x FilesystemVersion_VersionType) String() string { return proto.EnumName(FilesystemVersion_VersionType_name, int32(x)) } func (FilesystemVersion_VersionType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5, 0} + return fileDescriptor_pdu_89315d819a6e0938, []int{5, 0} } type ListFilesystemReq struct { @@ -51,7 +56,7 @@ func (m *ListFilesystemReq) Reset() { *m = ListFilesystemReq{} } func (m *ListFilesystemReq) String() string { return proto.CompactTextString(m) } func (*ListFilesystemReq) ProtoMessage() {} func (*ListFilesystemReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{0} + return fileDescriptor_pdu_89315d819a6e0938, []int{0} } func (m *ListFilesystemReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemReq.Unmarshal(m, b) @@ -73,6 +78,7 @@ var xxx_messageInfo_ListFilesystemReq proto.InternalMessageInfo type ListFilesystemRes struct { Filesystems []*Filesystem `protobuf:"bytes,1,rep,name=Filesystems,proto3" json:"Filesystems,omitempty"` + Empty bool `protobuf:"varint,2,opt,name=Empty,proto3" json:"Empty,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -82,7 +88,7 @@ func (m *ListFilesystemRes) Reset() { *m = ListFilesystemRes{} } func (m *ListFilesystemRes) String() string { return proto.CompactTextString(m) } func (*ListFilesystemRes) ProtoMessage() {} func (*ListFilesystemRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{1} + return fileDescriptor_pdu_89315d819a6e0938, []int{1} } func (m *ListFilesystemRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemRes.Unmarshal(m, b) @@ -109,6 +115,13 @@ func (m *ListFilesystemRes) GetFilesystems() []*Filesystem { return nil } +func (m *ListFilesystemRes) GetEmpty() bool { + if m != nil { + return m.Empty + } + return false +} 
+ type Filesystem struct { Path string `protobuf:"bytes,1,opt,name=Path,proto3" json:"Path,omitempty"` ResumeToken string `protobuf:"bytes,2,opt,name=ResumeToken,proto3" json:"ResumeToken,omitempty"` @@ -121,7 +134,7 @@ func (m *Filesystem) Reset() { *m = Filesystem{} } func (m *Filesystem) String() string { return proto.CompactTextString(m) } func (*Filesystem) ProtoMessage() {} func (*Filesystem) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{2} + return fileDescriptor_pdu_89315d819a6e0938, []int{2} } func (m *Filesystem) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Filesystem.Unmarshal(m, b) @@ -166,7 +179,7 @@ func (m *ListFilesystemVersionsReq) Reset() { *m = ListFilesystemVersion func (m *ListFilesystemVersionsReq) String() string { return proto.CompactTextString(m) } func (*ListFilesystemVersionsReq) ProtoMessage() {} func (*ListFilesystemVersionsReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{3} + return fileDescriptor_pdu_89315d819a6e0938, []int{3} } func (m *ListFilesystemVersionsReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemVersionsReq.Unmarshal(m, b) @@ -204,7 +217,7 @@ func (m *ListFilesystemVersionsRes) Reset() { *m = ListFilesystemVersion func (m *ListFilesystemVersionsRes) String() string { return proto.CompactTextString(m) } func (*ListFilesystemVersionsRes) ProtoMessage() {} func (*ListFilesystemVersionsRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{4} + return fileDescriptor_pdu_89315d819a6e0938, []int{4} } func (m *ListFilesystemVersionsRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ListFilesystemVersionsRes.Unmarshal(m, b) @@ -232,7 +245,7 @@ func (m *ListFilesystemVersionsRes) GetVersions() []*FilesystemVersion { } type FilesystemVersion struct { - Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=pdu.FilesystemVersion_VersionType" 
json:"Type,omitempty"` + Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=FilesystemVersion_VersionType" json:"Type,omitempty"` Name string `protobuf:"bytes,2,opt,name=Name,proto3" json:"Name,omitempty"` Guid uint64 `protobuf:"varint,3,opt,name=Guid,proto3" json:"Guid,omitempty"` CreateTXG uint64 `protobuf:"varint,4,opt,name=CreateTXG,proto3" json:"CreateTXG,omitempty"` @@ -246,7 +259,7 @@ func (m *FilesystemVersion) Reset() { *m = FilesystemVersion{} } func (m *FilesystemVersion) String() string { return proto.CompactTextString(m) } func (*FilesystemVersion) ProtoMessage() {} func (*FilesystemVersion) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5} + return fileDescriptor_pdu_89315d819a6e0938, []int{5} } func (m *FilesystemVersion) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_FilesystemVersion.Unmarshal(m, b) @@ -326,7 +339,7 @@ func (m *SendReq) Reset() { *m = SendReq{} } func (m *SendReq) String() string { return proto.CompactTextString(m) } func (*SendReq) ProtoMessage() {} func (*SendReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{6} + return fileDescriptor_pdu_89315d819a6e0938, []int{6} } func (m *SendReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_SendReq.Unmarshal(m, b) @@ -407,7 +420,7 @@ func (m *Property) Reset() { *m = Property{} } func (m *Property) String() string { return proto.CompactTextString(m) } func (*Property) ProtoMessage() {} func (*Property) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{7} + return fileDescriptor_pdu_89315d819a6e0938, []int{7} } func (m *Property) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Property.Unmarshal(m, b) @@ -443,11 +456,11 @@ func (m *Property) GetValue() string { type SendRes struct { // Whether the resume token provided in the request has been used or not. 
- UsedResumeToken bool `protobuf:"varint,1,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"` + UsedResumeToken bool `protobuf:"varint,2,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"` // Expected stream size determined by dry run, not exact. // 0 indicates that for the given SendReq, no size estimate could be made. - ExpectedSize int64 `protobuf:"varint,2,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"` - Properties []*Property `protobuf:"bytes,3,rep,name=Properties,proto3" json:"Properties,omitempty"` + ExpectedSize int64 `protobuf:"varint,3,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"` + Properties []*Property `protobuf:"bytes,4,rep,name=Properties,proto3" json:"Properties,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -457,7 +470,7 @@ func (m *SendRes) Reset() { *m = SendRes{} } func (m *SendRes) String() string { return proto.CompactTextString(m) } func (*SendRes) ProtoMessage() {} func (*SendRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{8} + return fileDescriptor_pdu_89315d819a6e0938, []int{8} } func (m *SendRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_SendRes.Unmarshal(m, b) @@ -511,7 +524,7 @@ func (m *ReceiveReq) Reset() { *m = ReceiveReq{} } func (m *ReceiveReq) String() string { return proto.CompactTextString(m) } func (*ReceiveReq) ProtoMessage() {} func (*ReceiveReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{9} + return fileDescriptor_pdu_89315d819a6e0938, []int{9} } func (m *ReceiveReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReceiveReq.Unmarshal(m, b) @@ -555,7 +568,7 @@ func (m *ReceiveRes) Reset() { *m = ReceiveRes{} } func (m *ReceiveRes) String() string { return proto.CompactTextString(m) } func (*ReceiveRes) ProtoMessage() {} func (*ReceiveRes) Descriptor() ([]byte, []int) { - return 
fileDescriptor_pdu_fe566e6b212fcf8d, []int{10} + return fileDescriptor_pdu_89315d819a6e0938, []int{10} } func (m *ReceiveRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReceiveRes.Unmarshal(m, b) @@ -588,7 +601,7 @@ func (m *DestroySnapshotsReq) Reset() { *m = DestroySnapshotsReq{} } func (m *DestroySnapshotsReq) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotsReq) ProtoMessage() {} func (*DestroySnapshotsReq) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{11} + return fileDescriptor_pdu_89315d819a6e0938, []int{11} } func (m *DestroySnapshotsReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotsReq.Unmarshal(m, b) @@ -634,7 +647,7 @@ func (m *DestroySnapshotRes) Reset() { *m = DestroySnapshotRes{} } func (m *DestroySnapshotRes) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotRes) ProtoMessage() {} func (*DestroySnapshotRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{12} + return fileDescriptor_pdu_89315d819a6e0938, []int{12} } func (m *DestroySnapshotRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotRes.Unmarshal(m, b) @@ -679,7 +692,7 @@ func (m *DestroySnapshotsRes) Reset() { *m = DestroySnapshotsRes{} } func (m *DestroySnapshotsRes) String() string { return proto.CompactTextString(m) } func (*DestroySnapshotsRes) ProtoMessage() {} func (*DestroySnapshotsRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{13} + return fileDescriptor_pdu_89315d819a6e0938, []int{13} } func (m *DestroySnapshotsRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_DestroySnapshotsRes.Unmarshal(m, b) @@ -721,7 +734,7 @@ func (m *ReplicationCursorReq) Reset() { *m = ReplicationCursorReq{} } func (m *ReplicationCursorReq) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq) ProtoMessage() {} func (*ReplicationCursorReq) 
Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14} + return fileDescriptor_pdu_89315d819a6e0938, []int{14} } func (m *ReplicationCursorReq) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq.Unmarshal(m, b) @@ -869,7 +882,7 @@ func (m *ReplicationCursorReq_GetOp) Reset() { *m = ReplicationCursorReq func (m *ReplicationCursorReq_GetOp) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq_GetOp) ProtoMessage() {} func (*ReplicationCursorReq_GetOp) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 0} + return fileDescriptor_pdu_89315d819a6e0938, []int{14, 0} } func (m *ReplicationCursorReq_GetOp) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq_GetOp.Unmarshal(m, b) @@ -900,7 +913,7 @@ func (m *ReplicationCursorReq_SetOp) Reset() { *m = ReplicationCursorReq func (m *ReplicationCursorReq_SetOp) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorReq_SetOp) ProtoMessage() {} func (*ReplicationCursorReq_SetOp) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 1} + return fileDescriptor_pdu_89315d819a6e0938, []int{14, 1} } func (m *ReplicationCursorReq_SetOp) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorReq_SetOp.Unmarshal(m, b) @@ -941,7 +954,7 @@ func (m *ReplicationCursorRes) Reset() { *m = ReplicationCursorRes{} } func (m *ReplicationCursorRes) String() string { return proto.CompactTextString(m) } func (*ReplicationCursorRes) ProtoMessage() {} func (*ReplicationCursorRes) Descriptor() ([]byte, []int) { - return fileDescriptor_pdu_fe566e6b212fcf8d, []int{15} + return fileDescriptor_pdu_89315d819a6e0938, []int{15} } func (m *ReplicationCursorRes) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_ReplicationCursorRes.Unmarshal(m, b) @@ -1067,71 +1080,246 @@ func _ReplicationCursorRes_OneofSizer(msg proto.Message) (n int) { } 
func init() { - proto.RegisterType((*ListFilesystemReq)(nil), "pdu.ListFilesystemReq") - proto.RegisterType((*ListFilesystemRes)(nil), "pdu.ListFilesystemRes") - proto.RegisterType((*Filesystem)(nil), "pdu.Filesystem") - proto.RegisterType((*ListFilesystemVersionsReq)(nil), "pdu.ListFilesystemVersionsReq") - proto.RegisterType((*ListFilesystemVersionsRes)(nil), "pdu.ListFilesystemVersionsRes") - proto.RegisterType((*FilesystemVersion)(nil), "pdu.FilesystemVersion") - proto.RegisterType((*SendReq)(nil), "pdu.SendReq") - proto.RegisterType((*Property)(nil), "pdu.Property") - proto.RegisterType((*SendRes)(nil), "pdu.SendRes") - proto.RegisterType((*ReceiveReq)(nil), "pdu.ReceiveReq") - proto.RegisterType((*ReceiveRes)(nil), "pdu.ReceiveRes") - proto.RegisterType((*DestroySnapshotsReq)(nil), "pdu.DestroySnapshotsReq") - proto.RegisterType((*DestroySnapshotRes)(nil), "pdu.DestroySnapshotRes") - proto.RegisterType((*DestroySnapshotsRes)(nil), "pdu.DestroySnapshotsRes") - proto.RegisterType((*ReplicationCursorReq)(nil), "pdu.ReplicationCursorReq") - proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "pdu.ReplicationCursorReq.GetOp") - proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "pdu.ReplicationCursorReq.SetOp") - proto.RegisterType((*ReplicationCursorRes)(nil), "pdu.ReplicationCursorRes") - proto.RegisterEnum("pdu.FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) + proto.RegisterType((*ListFilesystemReq)(nil), "ListFilesystemReq") + proto.RegisterType((*ListFilesystemRes)(nil), "ListFilesystemRes") + proto.RegisterType((*Filesystem)(nil), "Filesystem") + proto.RegisterType((*ListFilesystemVersionsReq)(nil), "ListFilesystemVersionsReq") + proto.RegisterType((*ListFilesystemVersionsRes)(nil), "ListFilesystemVersionsRes") + proto.RegisterType((*FilesystemVersion)(nil), "FilesystemVersion") + proto.RegisterType((*SendReq)(nil), "SendReq") + proto.RegisterType((*Property)(nil), "Property") + 
proto.RegisterType((*SendRes)(nil), "SendRes") + proto.RegisterType((*ReceiveReq)(nil), "ReceiveReq") + proto.RegisterType((*ReceiveRes)(nil), "ReceiveRes") + proto.RegisterType((*DestroySnapshotsReq)(nil), "DestroySnapshotsReq") + proto.RegisterType((*DestroySnapshotRes)(nil), "DestroySnapshotRes") + proto.RegisterType((*DestroySnapshotsRes)(nil), "DestroySnapshotsRes") + proto.RegisterType((*ReplicationCursorReq)(nil), "ReplicationCursorReq") + proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "ReplicationCursorReq.GetOp") + proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "ReplicationCursorReq.SetOp") + proto.RegisterType((*ReplicationCursorRes)(nil), "ReplicationCursorRes") + proto.RegisterEnum("FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value) } -func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_fe566e6b212fcf8d) } +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn -var fileDescriptor_pdu_fe566e6b212fcf8d = []byte{ - // 659 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdb, 0x6e, 0x13, 0x31, - 0x10, 0xcd, 0xe6, 0xba, 0x99, 0x94, 0x5e, 0xdc, 0xaa, 0x2c, 0x15, 0x82, 0xc8, 0xbc, 0x04, 0x24, - 0x22, 0x91, 0x56, 0xbc, 0xf0, 0x96, 0xde, 0xf2, 0x80, 0xda, 0xca, 0x09, 0x55, 0x9f, 0x90, 0x42, - 0x77, 0x44, 0x57, 0xb9, 0x78, 0x6b, 0x7b, 0x51, 0xc3, 0x07, 0xf0, 0x4f, 0xfc, 0x07, 0x0f, 0x7c, - 0x0e, 0xf2, 0xec, 0x25, 0xdb, 0x24, 0x54, 0x79, 0x8a, 0xcf, 0xf8, 0x78, 0xe6, 0xcc, 0xf1, 0x8e, - 0x03, 0xf5, 0xd0, 0x8f, 0xda, 0xa1, 0x92, 0x46, 0xb2, 0x52, 0xe8, 0x47, 0x7c, 0x17, 0x76, 0x3e, - 0x07, 0xda, 0x9c, 0x05, 0x63, 0xd4, 0x33, 0x6d, 0x70, 0x22, 0xf0, 0x9e, 0x9f, 0x2d, 0x07, 0x35, - 0xfb, 0x00, 0x8d, 0x79, 0x40, 0x7b, 0x4e, 0xb3, 0xd4, 0x6a, 0x74, 0xb6, 0xda, 0x36, 0x5f, 0x8e, - 0x98, 0xe7, 0xf0, 0x2e, 0xc0, 0x1c, 0x32, 0x06, 0xe5, 0xab, 0xa1, 0xb9, 0xf3, 
0x9c, 0xa6, 0xd3, - 0xaa, 0x0b, 0x5a, 0xb3, 0x26, 0x34, 0x04, 0xea, 0x68, 0x82, 0x03, 0x39, 0xc2, 0xa9, 0x57, 0xa4, - 0xad, 0x7c, 0x88, 0x7f, 0x82, 0x17, 0x8f, 0xb5, 0x5c, 0xa3, 0xd2, 0x81, 0x9c, 0x6a, 0x81, 0xf7, - 0xec, 0x55, 0xbe, 0x40, 0x92, 0x38, 0x17, 0xe1, 0x97, 0xff, 0x3f, 0xac, 0x59, 0x07, 0xdc, 0x14, - 0x26, 0xdd, 0xec, 0x2f, 0x74, 0x93, 0x6c, 0x8b, 0x8c, 0xc7, 0xff, 0x3a, 0xb0, 0xb3, 0xb4, 0xcf, - 0x3e, 0x42, 0x79, 0x30, 0x0b, 0x91, 0x04, 0x6c, 0x76, 0xf8, 0xea, 0x2c, 0xed, 0xe4, 0xd7, 0x32, - 0x05, 0xf1, 0xad, 0x23, 0x17, 0xc3, 0x09, 0x26, 0x6d, 0xd3, 0xda, 0xc6, 0xce, 0xa3, 0xc0, 0xf7, - 0x4a, 0x4d, 0xa7, 0x55, 0x16, 0xb4, 0x66, 0x2f, 0xa1, 0x7e, 0xac, 0x70, 0x68, 0x70, 0x70, 0x73, - 0xee, 0x95, 0x69, 0x63, 0x1e, 0x60, 0x07, 0xe0, 0x12, 0x08, 0xe4, 0xd4, 0xab, 0x50, 0xa6, 0x0c, - 0xf3, 0xb7, 0xd0, 0xc8, 0x95, 0x65, 0x1b, 0xe0, 0xf6, 0xa7, 0xc3, 0x50, 0xdf, 0x49, 0xb3, 0x5d, - 0xb0, 0xa8, 0x2b, 0xe5, 0x68, 0x32, 0x54, 0xa3, 0x6d, 0x87, 0xff, 0x76, 0xa0, 0xd6, 0xc7, 0xa9, - 0xbf, 0x86, 0xaf, 0x56, 0xe4, 0x99, 0x92, 0x93, 0x54, 0xb8, 0x5d, 0xb3, 0x4d, 0x28, 0x0e, 0x24, - 0xc9, 0xae, 0x8b, 0xe2, 0x40, 0x2e, 0x5e, 0x6d, 0x79, 0xe9, 0x6a, 0x49, 0xb8, 0x9c, 0x84, 0x0a, - 0xb5, 0x26, 0xe1, 0xae, 0xc8, 0x30, 0xdb, 0x83, 0xca, 0x09, 0xfa, 0x51, 0xe8, 0x55, 0x69, 0x23, - 0x06, 0x6c, 0x1f, 0xaa, 0x27, 0x6a, 0x26, 0xa2, 0xa9, 0x57, 0xa3, 0x70, 0x82, 0xf8, 0x11, 0xb8, - 0x57, 0x4a, 0x86, 0xa8, 0xcc, 0x2c, 0x33, 0xd5, 0xc9, 0x99, 0xba, 0x07, 0x95, 0xeb, 0xe1, 0x38, - 0x4a, 0x9d, 0x8e, 0x01, 0xff, 0x95, 0x75, 0xac, 0x59, 0x0b, 0xb6, 0xbe, 0x68, 0xf4, 0xf3, 0x8a, - 0x1d, 0x2a, 0xb1, 0x18, 0x66, 0x1c, 0x36, 0x4e, 0x1f, 0x42, 0xbc, 0x35, 0xe8, 0xf7, 0x83, 0x9f, - 0x71, 0xca, 0x92, 0x78, 0x14, 0x63, 0xef, 0x01, 0x12, 0x3d, 0x01, 0x6a, 0xaf, 0x44, 0x1f, 0xd7, - 0x33, 0xfa, 0x2c, 0x52, 0x99, 0x22, 0x47, 0xe0, 0x37, 0x00, 0x02, 0x6f, 0x31, 0xf8, 0x81, 0xeb, - 0x98, 0xff, 0x0e, 0xb6, 0x8f, 0xc7, 0x38, 0x54, 0x8b, 0x83, 0xe3, 0x8a, 0xa5, 0x38, 0xdf, 0xc8, - 0x65, 0xd6, 0x7c, 
0x04, 0xbb, 0x27, 0xa8, 0x8d, 0x92, 0xb3, 0xf4, 0x2b, 0x58, 0x67, 0x8a, 0xd8, - 0x11, 0xd4, 0x33, 0xbe, 0x57, 0x7c, 0x72, 0x52, 0xe6, 0x44, 0xfe, 0x15, 0xd8, 0x42, 0xb1, 0x64, - 0xe8, 0x52, 0x48, 0x95, 0x9e, 0x18, 0xba, 0x94, 0x67, 0x6f, 0xef, 0x54, 0x29, 0xa9, 0xd2, 0xdb, - 0x23, 0xc0, 0x7b, 0xab, 0x9a, 0xb1, 0xcf, 0x54, 0xcd, 0x1a, 0x30, 0x36, 0xe9, 0x50, 0x3f, 0xa7, - 0xfc, 0xcb, 0x52, 0x44, 0xca, 0xe3, 0x7f, 0x1c, 0xd8, 0x13, 0x18, 0x8e, 0x83, 0x5b, 0x1a, 0x9a, - 0xe3, 0x48, 0x69, 0xa9, 0xd6, 0x31, 0xe6, 0x10, 0x4a, 0xdf, 0xd1, 0x90, 0xac, 0x46, 0xe7, 0x35, - 0xd5, 0x59, 0x95, 0xa7, 0x7d, 0x8e, 0xe6, 0x32, 0xec, 0x15, 0x84, 0x65, 0xdb, 0x43, 0x1a, 0x0d, - 0x0d, 0xca, 0x93, 0x87, 0xfa, 0xe9, 0x21, 0x8d, 0xe6, 0xa0, 0x06, 0x15, 0x4a, 0x72, 0xf0, 0x06, - 0x2a, 0xb4, 0x61, 0x87, 0x27, 0x33, 0x32, 0xf6, 0x25, 0xc3, 0xdd, 0x32, 0x14, 0x65, 0xc8, 0x07, - 0x2b, 0xbb, 0xb2, 0xa3, 0x15, 0xbf, 0x30, 0xb6, 0x9f, 0x72, 0xaf, 0x90, 0xbd, 0x31, 0xee, 0x85, - 0x34, 0xf8, 0x10, 0xe8, 0x38, 0x9f, 0xdb, 0x2b, 0x88, 0x2c, 0xd2, 0x75, 0xa1, 0x1a, 0xbb, 0xf5, - 0xad, 0x4a, 0x7f, 0x1e, 0x87, 0xff, 0x02, 0x00, 0x00, 0xff, 0xff, 0x66, 0x74, 0x36, 0x3a, 0x49, - 0x06, 0x00, 0x00, +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// ReplicationClient is the client API for Replication service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. 
+type ReplicationClient interface { + ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error) + ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error) + DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error) + ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error) +} + +type replicationClient struct { + cc *grpc.ClientConn +} + +func NewReplicationClient(cc *grpc.ClientConn) ReplicationClient { + return &replicationClient{cc} +} + +func (c *replicationClient) ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error) { + out := new(ListFilesystemRes) + err := c.cc.Invoke(ctx, "/Replication/ListFilesystems", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *replicationClient) ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error) { + out := new(ListFilesystemVersionsRes) + err := c.cc.Invoke(ctx, "/Replication/ListFilesystemVersions", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *replicationClient) DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error) { + out := new(DestroySnapshotsRes) + err := c.cc.Invoke(ctx, "/Replication/DestroySnapshots", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *replicationClient) ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error) { + out := new(ReplicationCursorRes) + err := c.cc.Invoke(ctx, "/Replication/ReplicationCursor", in, out, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// ReplicationServer is the server API for Replication service. +type ReplicationServer interface { + ListFilesystems(context.Context, *ListFilesystemReq) (*ListFilesystemRes, error) + ListFilesystemVersions(context.Context, *ListFilesystemVersionsReq) (*ListFilesystemVersionsRes, error) + DestroySnapshots(context.Context, *DestroySnapshotsReq) (*DestroySnapshotsRes, error) + ReplicationCursor(context.Context, *ReplicationCursorReq) (*ReplicationCursorRes, error) +} + +func RegisterReplicationServer(s *grpc.Server, srv ReplicationServer) { + s.RegisterService(&_Replication_serviceDesc, srv) +} + +func _Replication_ListFilesystems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListFilesystemReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).ListFilesystems(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/ListFilesystems", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).ListFilesystems(ctx, req.(*ListFilesystemReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _Replication_ListFilesystemVersions_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListFilesystemVersionsReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).ListFilesystemVersions(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/ListFilesystemVersions", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).ListFilesystemVersions(ctx, req.(*ListFilesystemVersionsReq)) + } + return 
interceptor(ctx, in, info, handler) +} + +func _Replication_DestroySnapshots_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DestroySnapshotsReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).DestroySnapshots(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/DestroySnapshots", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).DestroySnapshots(ctx, req.(*DestroySnapshotsReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _Replication_ReplicationCursor_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ReplicationCursorReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReplicationServer).ReplicationCursor(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/Replication/ReplicationCursor", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReplicationServer).ReplicationCursor(ctx, req.(*ReplicationCursorReq)) + } + return interceptor(ctx, in, info, handler) +} + +var _Replication_serviceDesc = grpc.ServiceDesc{ + ServiceName: "Replication", + HandlerType: (*ReplicationServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ListFilesystems", + Handler: _Replication_ListFilesystems_Handler, + }, + { + MethodName: "ListFilesystemVersions", + Handler: _Replication_ListFilesystemVersions_Handler, + }, + { + MethodName: "DestroySnapshots", + Handler: _Replication_DestroySnapshots_Handler, + }, + { + MethodName: "ReplicationCursor", + Handler: _Replication_ReplicationCursor_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "pdu.proto", +} + +func init() 
{ proto.RegisterFile("pdu.proto", fileDescriptor_pdu_89315d819a6e0938) } + +var fileDescriptor_pdu_89315d819a6e0938 = []byte{ + // 735 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdd, 0x6e, 0xda, 0x4a, + 0x10, 0xc6, 0x60, 0xc0, 0x0c, 0x51, 0x42, 0x36, 0x9c, 0xc8, 0xc7, 0xe7, 0x28, 0x42, 0xdb, 0x1b, + 0x52, 0xa9, 0x6e, 0x45, 0x7b, 0x53, 0x55, 0xaa, 0x54, 0x42, 0x7e, 0xa4, 0x56, 0x69, 0xb4, 0xd0, + 0x28, 0xca, 0x1d, 0x0d, 0xa3, 0xc4, 0x0a, 0xb0, 0xce, 0xee, 0xba, 0x0a, 0xbd, 0xec, 0x7b, 0xf4, + 0x41, 0xfa, 0x0e, 0xbd, 0xec, 0x03, 0x55, 0xbb, 0x60, 0xe3, 0x60, 0x23, 0x71, 0xe5, 0xfd, 0xbe, + 0x9d, 0x9d, 0x9d, 0xf9, 0x76, 0x66, 0x0c, 0xb5, 0x70, 0x14, 0xf9, 0xa1, 0xe0, 0x8a, 0xd3, 0x3d, + 0xd8, 0xfd, 0x14, 0x48, 0x75, 0x12, 0x8c, 0x51, 0xce, 0xa4, 0xc2, 0x09, 0xc3, 0x07, 0x7a, 0x95, + 0x25, 0x25, 0x79, 0x01, 0xf5, 0x25, 0x21, 0x5d, 0xab, 0x55, 0x6a, 0xd7, 0x3b, 0x75, 0x3f, 0x65, + 0x94, 0xde, 0x27, 0x4d, 0x28, 0x1f, 0x4f, 0x42, 0x35, 0x73, 0x8b, 0x2d, 0xab, 0xed, 0xb0, 0x39, + 0xa0, 0x5d, 0x80, 0xa5, 0x11, 0x21, 0x60, 0x5f, 0x0c, 0xd5, 0x9d, 0x6b, 0xb5, 0xac, 0x76, 0x8d, + 0x99, 0x35, 0x69, 0x41, 0x9d, 0xa1, 0x8c, 0x26, 0x38, 0xe0, 0xf7, 0x38, 0x35, 0xa7, 0x6b, 0x2c, + 0x4d, 0xd1, 0x77, 0xf0, 0xef, 0xd3, 0xe8, 0x2e, 0x51, 0xc8, 0x80, 0x4f, 0x25, 0xc3, 0x07, 0x72, + 0x90, 0xbe, 0x60, 0xe1, 0x38, 0xc5, 0xd0, 0x8f, 0xeb, 0x0f, 0x4b, 0xe2, 0x83, 0x13, 0xc3, 0x45, + 0x7e, 0xc4, 0xcf, 0x58, 0xb2, 0xc4, 0x86, 0xfe, 0xb1, 0x60, 0x37, 0xb3, 0x4f, 0x3a, 0x60, 0x0f, + 0x66, 0x21, 0x9a, 0xcb, 0xb7, 0x3b, 0x07, 0x59, 0x0f, 0xfe, 0xe2, 0xab, 0xad, 0x98, 0xb1, 0xd5, + 0x4a, 0x9c, 0x0f, 0x27, 0xb8, 0x48, 0xd7, 0xac, 0x35, 0x77, 0x1a, 0x05, 0x23, 0xb7, 0xd4, 0xb2, + 0xda, 0x36, 0x33, 0x6b, 0xf2, 0x3f, 0xd4, 0x8e, 0x04, 0x0e, 0x15, 0x0e, 0xae, 0x4e, 0x5d, 0xdb, + 0x6c, 0x2c, 0x09, 0xe2, 0x81, 0x63, 0x40, 0xc0, 0xa7, 0x6e, 0xd9, 0x78, 0x4a, 0x30, 0x3d, 0x84, + 0x7a, 0xea, 0x5a, 0xb2, 0x05, 0x4e, 0x7f, 0x3a, 0x0c, 0xe5, 
0x1d, 0x57, 0x8d, 0x82, 0x46, 0x5d, + 0xce, 0xef, 0x27, 0x43, 0x71, 0xdf, 0xb0, 0xe8, 0x2f, 0x0b, 0xaa, 0x7d, 0x9c, 0x8e, 0x36, 0xd0, + 0x53, 0x07, 0x79, 0x22, 0xf8, 0x24, 0x0e, 0x5c, 0xaf, 0xc9, 0x36, 0x14, 0x07, 0xdc, 0x84, 0x5d, + 0x63, 0xc5, 0x01, 0x5f, 0x7d, 0x52, 0x3b, 0xf3, 0xa4, 0x26, 0x70, 0x3e, 0x09, 0x05, 0x4a, 0x69, + 0x02, 0x77, 0x58, 0x82, 0x75, 0x21, 0xf5, 0x70, 0x14, 0x85, 0x6e, 0x65, 0x5e, 0x48, 0x06, 0x90, + 0x7d, 0xa8, 0xf4, 0xc4, 0x8c, 0x45, 0x53, 0xb7, 0x6a, 0xe8, 0x05, 0xa2, 0x6f, 0xc0, 0xb9, 0x10, + 0x3c, 0x44, 0xa1, 0x66, 0x89, 0xa8, 0x56, 0x4a, 0xd4, 0x26, 0x94, 0x2f, 0x87, 0xe3, 0x28, 0x56, + 0x7a, 0x0e, 0xe8, 0x8f, 0x24, 0x63, 0x49, 0xda, 0xb0, 0xf3, 0x45, 0xe2, 0x68, 0xb5, 0x08, 0x1d, + 0xb6, 0x4a, 0x13, 0x0a, 0x5b, 0xc7, 0x8f, 0x21, 0xde, 0x28, 0x1c, 0xf5, 0x83, 0xef, 0x68, 0x32, + 0x2e, 0xb1, 0x27, 0x1c, 0x39, 0x04, 0x58, 0xc4, 0x13, 0xa0, 0x74, 0x6d, 0x53, 0x54, 0x35, 0x3f, + 0x0e, 0x91, 0xa5, 0x36, 0xe9, 0x15, 0x00, 0xc3, 0x1b, 0x0c, 0xbe, 0xe1, 0x26, 0xc2, 0x3f, 0x87, + 0xc6, 0xd1, 0x18, 0x87, 0x22, 0x1b, 0x67, 0x86, 0xa7, 0x5b, 0x29, 0xcf, 0x92, 0xde, 0xc2, 0x5e, + 0x0f, 0xa5, 0x12, 0x7c, 0x16, 0x57, 0xc0, 0x26, 0x9d, 0x43, 0x5e, 0x41, 0x2d, 0xb1, 0x77, 0x8b, + 0x6b, 0xbb, 0x63, 0x69, 0x44, 0xaf, 0x81, 0xac, 0x5c, 0xb4, 0x68, 0xb2, 0x18, 0x9a, 0x5b, 0xd6, + 0x34, 0x59, 0x6c, 0x63, 0x06, 0x89, 0x10, 0x5c, 0xc4, 0x2f, 0x66, 0x00, 0xed, 0xe5, 0x25, 0xa1, + 0x87, 0x54, 0x55, 0x27, 0x3e, 0x56, 0x71, 0x03, 0xef, 0xf9, 0xd9, 0x10, 0x58, 0x6c, 0x43, 0x7f, + 0x5b, 0xd0, 0x64, 0x18, 0x8e, 0x83, 0x1b, 0xd3, 0x24, 0x47, 0x91, 0x90, 0x5c, 0x6c, 0x22, 0xc6, + 0x4b, 0x28, 0xdd, 0xa2, 0x32, 0x21, 0xd5, 0x3b, 0xff, 0xf9, 0x79, 0x3e, 0xfc, 0x53, 0x54, 0x9f, + 0xc3, 0xb3, 0x02, 0xd3, 0x96, 0xfa, 0x80, 0x44, 0x65, 0x4a, 0x64, 0xed, 0x81, 0x7e, 0x7c, 0x40, + 0xa2, 0xf2, 0xaa, 0x50, 0x36, 0x0e, 0xbc, 0x67, 0x50, 0x36, 0x1b, 0xba, 0x49, 0x12, 0xe1, 0xe6, + 0x5a, 0x24, 0xb8, 0x6b, 0x43, 0x91, 0x87, 0x74, 0x90, 0x9b, 0x8d, 0x6e, 0xa1, 0xf9, 0x24, 0xd1, + 
0x79, 0xd8, 0x67, 0x85, 0x64, 0x96, 0x38, 0xe7, 0x5c, 0xe1, 0x63, 0x20, 0xe7, 0xfe, 0x9c, 0xb3, + 0x02, 0x4b, 0x98, 0xae, 0x03, 0x95, 0xb9, 0x4a, 0x9d, 0x9f, 0x45, 0xdd, 0xbf, 0x89, 0x5b, 0xf2, + 0x16, 0x76, 0x9e, 0x8e, 0x50, 0x49, 0x88, 0x9f, 0xf9, 0x89, 0x78, 0x59, 0x4e, 0x92, 0x0b, 0xd8, + 0xcf, 0x9f, 0xbe, 0xc4, 0xf3, 0xd7, 0xce, 0x74, 0x6f, 0xfd, 0x9e, 0x24, 0xef, 0xa1, 0xb1, 0x5a, + 0x07, 0xa4, 0xe9, 0xe7, 0xd4, 0xb7, 0x97, 0xc7, 0x4a, 0xf2, 0x01, 0x76, 0x33, 0x92, 0x91, 0x7f, + 0x72, 0xdf, 0xc7, 0xcb, 0xa5, 0x65, 0xb7, 0x7c, 0x5d, 0x0a, 0x47, 0xd1, 0xd7, 0x8a, 0xf9, 0xa1, + 0xbe, 0xfe, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xa3, 0xba, 0x8e, 0x63, 0x5d, 0x07, 0x00, 0x00, } diff --git a/replication/pdu/pdu.proto b/replication/pdu/pdu.proto index 6d9430a..1b66916 100644 --- a/replication/pdu/pdu.proto +++ b/replication/pdu/pdu.proto @@ -1,11 +1,19 @@ syntax = "proto3"; +option go_package = "pdu"; -package pdu; +service Replication { + rpc ListFilesystems (ListFilesystemReq) returns (ListFilesystemRes); + rpc ListFilesystemVersions (ListFilesystemVersionsReq) returns (ListFilesystemVersionsRes); + rpc DestroySnapshots (DestroySnapshotsReq) returns (DestroySnapshotsRes); + rpc ReplicationCursor (ReplicationCursorReq) returns (ReplicationCursorRes); + // for Send and Recv, see package rpc +} message ListFilesystemReq {} message ListFilesystemRes { repeated Filesystem Filesystems = 1; + bool Empty = 2; } message Filesystem { @@ -60,22 +68,18 @@ message Property { } message SendRes { - // The actual stream is in the stream part of the streamrpc response - // Whether the resume token provided in the request has been used or not. - bool UsedResumeToken = 1; + bool UsedResumeToken = 2; // Expected stream size determined by dry run, not exact. // 0 indicates that for the given SendReq, no size estimate could be made. 
- int64 ExpectedSize = 2; + int64 ExpectedSize = 3; - repeated Property Properties = 3; + repeated Property Properties = 4; } message ReceiveReq { - // The stream part of the streamrpc request contains the zfs send stream - - string Filesystem = 1; + string Filesystem = 1; // FIXME should be snapshot name, we can enforce that on recv // If true, the receiver should clear the resume token before perfoming the zfs recv of the stream in the request bool ClearResumeToken = 2; diff --git a/rpc/dataconn/base2bufpool/base2bufpool.go b/rpc/dataconn/base2bufpool/base2bufpool.go new file mode 100644 index 0000000..73efecc --- /dev/null +++ b/rpc/dataconn/base2bufpool/base2bufpool.go @@ -0,0 +1,169 @@ +package base2bufpool + +import ( + "fmt" + "math/bits" + "sync" +) + +type pool struct { + mtx sync.Mutex + bufs [][]byte + shift uint +} + +func (p *pool) Put(buf []byte) { + p.mtx.Lock() + defer p.mtx.Unlock() + if len(buf) != 1< 10 { // FIXME constant + return + } + p.bufs = append(p.bufs, buf) +} + +func (p *pool) Get() []byte { + p.mtx.Lock() + defer p.mtx.Unlock() + if len(p.bufs) > 0 { + ret := p.bufs[len(p.bufs)-1] + p.bufs = p.bufs[0 : len(p.bufs)-1] + return ret + } + return make([]byte, 1< b.payloadLen { + panic(fmt.Sprintf("shrink is actually an expand, invalid: %v %v", newPayloadLen, b.payloadLen)) + } + b.payloadLen = newPayloadLen +} + +func (b Buffer) Free() { + if b.pool != nil { + b.pool.put(b) + } +} + +//go:generate enumer -type NoFitBehavior +type NoFitBehavior uint + +const ( + AllocateSmaller NoFitBehavior = 1 << iota + AllocateLarger + + Allocate NoFitBehavior = AllocateSmaller | AllocateLarger + Panic NoFitBehavior = 0 +) + +func New(minShift, maxShift uint, noFitBehavior NoFitBehavior) *Pool { + if minShift > 63 || maxShift > 63 { + panic(fmt.Sprintf("{min|max}Shift are the _exponent_, got minShift=%v maxShift=%v and limit of 63, which amounts to %v bits", minShift, maxShift, uint64(1)<<63)) + } + pools := make([]pool, maxShift-minShift+1) + for i := 
uint(0); i < uint(len(pools)); i++ { + i := i // the closure below must copy i + pools[i] = pool{ + shift: minShift + i, + bufs: make([][]byte, 0, 10), + } + } + return &Pool{ + minShift: minShift, + maxShift: maxShift, + pools: pools, + onNoFit: noFitBehavior, + } +} + +func fittingShift(x uint) uint { + if x == 0 { + return 0 + } + blen := uint(bits.Len(x)) + if 1<<(blen-1) == x { + return blen - 1 + } + return blen +} + +func (p *Pool) handlePotentialNoFit(reqShift uint) (buf Buffer, didHandle bool) { + if reqShift == 0 { + if p.onNoFit&AllocateSmaller != 0 { + return Buffer{[]byte{}, 0, nil}, true + } else { + goto doPanic + } + } + if reqShift < p.minShift { + if p.onNoFit&AllocateSmaller != 0 { + goto alloc + } else { + goto doPanic + } + } + if reqShift > p.maxShift { + if p.onNoFit&AllocateLarger != 0 { + goto alloc + } else { + goto doPanic + } + } + return Buffer{}, false +alloc: + return Buffer{make([]byte, 1< 1 { + panic(fmt.Sprintf("putting buffer that is not power of two len: %v", len(buf))) + } + if len(buf) == 0 { + return + } + shift := fittingShift(uint(len(buf))) + if shift < p.minShift || shift > p.maxShift { + return // drop it + } + p.pools[shift-p.minShift].Put(buf) +} diff --git a/rpc/dataconn/base2bufpool/base2bufpool_test.go b/rpc/dataconn/base2bufpool/base2bufpool_test.go new file mode 100644 index 0000000..d6ce361 --- /dev/null +++ b/rpc/dataconn/base2bufpool/base2bufpool_test.go @@ -0,0 +1,98 @@ +package base2bufpool + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPoolAllocBehavior(t *testing.T) { + + type testcase struct { + poolMinShift, poolMaxShift uint + behavior NoFitBehavior + get uint + expShiftBufLen int64 // -1 if panic expected + } + + tcs := []testcase{ + { + 15, 20, Allocate, + 1 << 14, 1 << 14, + }, + { + 15, 20, Allocate, + 1 << 22, 1 << 22, + }, + { + 15, 20, Panic, + 1 << 16, 1 << 16, + }, + { + 15, 20, Panic, + 1 << 14, -1, + }, + { + 15, 20, Panic, + 1 << 22, -1, + }, + { + 15, 
20, Panic, + (1 << 15) + 23, 1 << 16, + }, + { + 15, 20, Panic, + 0, -1, // yep, 0 always works, even + }, + { + 15, 20, Allocate, + 0, 0, + }, + { + 15, 20, AllocateSmaller, + 1 << 14, 1 << 14, + }, + { + 15, 20, AllocateSmaller, + 1 << 22, -1, + }, + } + + for i := range tcs { + tc := tcs[i] + t.Run(fmt.Sprintf("[%d,%d] behav=%s Get(%d) exp=%d", tc.poolMinShift, tc.poolMaxShift, tc.behavior, tc.get, tc.expShiftBufLen), func(t *testing.T) { + pool := New(tc.poolMinShift, tc.poolMaxShift, tc.behavior) + if tc.expShiftBufLen == -1 { + assert.Panics(t, func() { + pool.Get(tc.get) + }) + return + } + buf := pool.Get(tc.get) + assert.True(t, uint(len(buf.Bytes())) == tc.get) + assert.True(t, int64(len(buf.shiftBuf)) == tc.expShiftBufLen) + }) + } +} + +func TestFittingShift(t *testing.T) { + assert.Equal(t, uint(16), fittingShift(1+1<<15)) + assert.Equal(t, uint(15), fittingShift(1<<15)) +} + +func TestFreeFromPoolRangeDoesNotPanic(t *testing.T) { + pool := New(15, 20, Allocate) + buf := pool.Get(1 << 16) + assert.NotPanics(t, func() { + buf.Free() + }) +} + +func TestFreeFromOutOfPoolRangeDoesNotPanic(t *testing.T) { + pool := New(15, 20, Allocate) + buf := pool.Get(1 << 23) + assert.NotPanics(t, func() { + buf.Free() + }) +} diff --git a/rpc/dataconn/base2bufpool/nofitbehavior_enumer.go b/rpc/dataconn/base2bufpool/nofitbehavior_enumer.go new file mode 100644 index 0000000..18f70b8 --- /dev/null +++ b/rpc/dataconn/base2bufpool/nofitbehavior_enumer.go @@ -0,0 +1,51 @@ +// Code generated by "enumer -type NoFitBehavior"; DO NOT EDIT. 
+ +package base2bufpool + +import ( + "fmt" +) + +const _NoFitBehaviorName = "PanicAllocateSmallerAllocateLargerAllocate" + +var _NoFitBehaviorIndex = [...]uint8{0, 5, 20, 34, 42} + +func (i NoFitBehavior) String() string { + if i >= NoFitBehavior(len(_NoFitBehaviorIndex)-1) { + return fmt.Sprintf("NoFitBehavior(%d)", i) + } + return _NoFitBehaviorName[_NoFitBehaviorIndex[i]:_NoFitBehaviorIndex[i+1]] +} + +var _NoFitBehaviorValues = []NoFitBehavior{0, 1, 2, 3} + +var _NoFitBehaviorNameToValueMap = map[string]NoFitBehavior{ + _NoFitBehaviorName[0:5]: 0, + _NoFitBehaviorName[5:20]: 1, + _NoFitBehaviorName[20:34]: 2, + _NoFitBehaviorName[34:42]: 3, +} + +// NoFitBehaviorString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func NoFitBehaviorString(s string) (NoFitBehavior, error) { + if val, ok := _NoFitBehaviorNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to NoFitBehavior values", s) +} + +// NoFitBehaviorValues returns all values of the enum +func NoFitBehaviorValues() []NoFitBehavior { + return _NoFitBehaviorValues +} + +// IsANoFitBehavior returns "true" if the value is listed in the enum definition. 
"false" otherwise +func (i NoFitBehavior) IsANoFitBehavior() bool { + for _, v := range _NoFitBehaviorValues { + if i == v { + return true + } + } + return false +} diff --git a/rpc/dataconn/dataconn_client.go b/rpc/dataconn/dataconn_client.go new file mode 100644 index 0000000..a12292b --- /dev/null +++ b/rpc/dataconn/dataconn_client.go @@ -0,0 +1,215 @@ +package dataconn + +import ( + "bytes" + "context" + "fmt" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn/stream" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/zfs" +) + +type Client struct { + log Logger + cn transport.Connecter +} + +func NewClient(connecter transport.Connecter, log Logger) *Client { + return &Client{ + log: log, + cn: connecter, + } +} + +func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error { + + var buf bytes.Buffer + _, memErr := buf.WriteString(endpoint) + if memErr != nil { + panic(memErr) + } + if err := conn.WriteStreamedMessage(ctx, &buf, ReqHeader); err != nil { + return err + } + + protobufBytes, err := proto.Marshal(req) + if err != nil { + return err + } + protobuf := bytes.NewBuffer(protobufBytes) + if err := conn.WriteStreamedMessage(ctx, protobuf, ReqStructured); err != nil { + return err + } + + if streamCopier != nil { + return conn.SendStream(ctx, streamCopier, ZFSStream) + } else { + return nil + } +} + +type RemoteHandlerError struct { + msg string +} + +func (e *RemoteHandlerError) Error() string { + return fmt.Sprintf("server error: %s", e.msg) +} + +type ProtocolError struct { + cause error +} + +func (e *ProtocolError) Error() string { + return fmt.Sprintf("protocol error: %s", e) +} + +func (c *Client) recv(ctx context.Context, conn *stream.Conn, res proto.Message) error { + + headerBuf, err := conn.ReadStreamedMessage(ctx, ResponseHeaderMaxSize, ResHeader) + if err != nil { + return err 
+ } + header := string(headerBuf) + if strings.HasPrefix(header, responseHeaderHandlerErrorPrefix) { + // FIXME distinguishable error type + return &RemoteHandlerError{strings.TrimPrefix(header, responseHeaderHandlerErrorPrefix)} + } + if !strings.HasPrefix(header, responseHeaderHandlerOk) { + return &ProtocolError{fmt.Errorf("invalid header: %q", header)} + } + + protobuf, err := conn.ReadStreamedMessage(ctx, ResponseStructuredMaxSize, ResStructured) + if err != nil { + return err + } + if err := proto.Unmarshal(protobuf, res); err != nil { + return &ProtocolError{fmt.Errorf("cannot unmarshal structured part of response: %s", err)} + } + return nil +} + +func (c *Client) getWire(ctx context.Context) (*stream.Conn, error) { + nc, err := c.cn.Connect(ctx) + if err != nil { + return nil, err + } + conn := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout) + return conn, nil +} + +func (c *Client) putWire(conn *stream.Conn) { + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("error closing connection") + } +} + +func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + conn, err := c.getWire(ctx) + if err != nil { + return nil, nil, err + } + putWireOnReturn := true + defer func() { + if putWireOnReturn { + c.putWire(conn) + } + }() + + if err := c.send(ctx, conn, EndpointSend, req, nil); err != nil { + return nil, nil, err + } + + var res pdu.SendRes + if err := c.recv(ctx, conn, &res); err != nil { + return nil, nil, err + } + + var copier zfs.StreamCopier = nil + if !req.DryRun { + putWireOnReturn = false + copier = &streamCopier{streamConn: conn, closeStreamOnClose: true} + } + + return &res, copier, nil +} + +func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) { + + defer c.log.Info("ReqRecv returns") + conn, err := c.getWire(ctx) + if err != nil { + return nil, err + } + + // send and recv response concurrently to catch 
early exists of remote handler + // (e.g. disk full, permission error, etc) + + type recvRes struct { + res *pdu.ReceiveRes + err error + } + recvErrChan := make(chan recvRes) + go func() { + res := &pdu.ReceiveRes{} + if err := c.recv(ctx, conn, res); err != nil { + recvErrChan <- recvRes{res, err} + } else { + recvErrChan <- recvRes{res, nil} + } + }() + + sendErrChan := make(chan error) + go func() { + if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil { + sendErrChan <- err + } else { + sendErrChan <- nil + } + }() + + var res recvRes + var sendErr error + var cause error // one of the above + didTryClose := false + for i := 0; i < 2; i++ { + select { + case res = <-recvErrChan: + c.log.WithField("errType", fmt.Sprintf("%T", res.err)).WithError(res.err).Debug("recv goroutine returned") + if res.err != nil && cause == nil { + cause = res.err + } + case sendErr = <-sendErrChan: + c.log.WithField("errType", fmt.Sprintf("%T", sendErr)).WithError(sendErr).Debug("send goroutine returned") + if sendErr != nil && cause == nil { + cause = sendErr + } + } + if !didTryClose && (res.err != nil || sendErr != nil) { + didTryClose = true + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("ReqRecv: cannot close connection, will likely block indefinitely") + } + c.log.WithError(err).Debug("ReqRecv: closed connection, should trigger other goroutine error") + } + } + + if !didTryClose { + // didn't close it in above loop, so we can give it back + c.putWire(conn) + } + + // if receive failed with a RemoteHandlerError, we know the transport was not broken + // => take the remote error as cause for the operation to fail + // TODO combine errors if send also failed + // (after all, send could have crashed on our side, rendering res.err a mere symptom of the cause) + if _, ok := res.err.(*RemoteHandlerError); ok { + cause = res.err + } + + return res.res, cause +} diff --git a/rpc/dataconn/dataconn_debug.go b/rpc/dataconn/dataconn_debug.go new 
file mode 100644 index 0000000..3c20701 --- /dev/null +++ b/rpc/dataconn/dataconn_debug.go @@ -0,0 +1,20 @@ +package dataconn + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DATACONN_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc/dataconn: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/dataconn/dataconn_server.go b/rpc/dataconn/dataconn_server.go new file mode 100644 index 0000000..41f5781 --- /dev/null +++ b/rpc/dataconn/dataconn_server.go @@ -0,0 +1,178 @@ +package dataconn + +import ( + "bytes" + "context" + "fmt" + + "github.com/golang/protobuf/proto" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn/stream" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/zfs" +) + +// WireInterceptor has a chance to exchange the context and connection on each client connection. +type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (context.Context, *transport.AuthConn) + +// Handler implements the functionality that is exposed by Server to the Client. +type Handler interface { + // Send handles a SendRequest. + // The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run. + Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) + // Receive handles a ReceiveRequest. + // It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout + // configured in ServerConfig.Shared.IdleConnTimeout. 
+ Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) +} + +type Logger = logger.Logger + +type Server struct { + h Handler + wi WireInterceptor + log Logger +} + +func NewServer(wi WireInterceptor, logger Logger, handler Handler) *Server { + return &Server{ + h: handler, + wi: wi, + log: logger, + } +} + +// Serve consumes the listener, closes it as soon as ctx is closed. +// No accept errors are returned: they are logged to the Logger passed +// to the constructor. +func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { + + go func() { + <-ctx.Done() + s.log.Debug("context done") + if err := l.Close(); err != nil { + s.log.WithError(err).Error("cannot close listener") + } + }() + conns := make(chan *transport.AuthConn) + go func() { + for { + conn, err := l.Accept(ctx) + if err != nil { + if ctx.Done() != nil { + s.log.Debug("stop accepting after context is done") + return + } + s.log.WithError(err).Error("accept error") + continue + } + conns <- conn + } + }() + for conn := range conns { + go s.serveConn(conn) + } +} + +func (s *Server) serveConn(nc *transport.AuthConn) { + s.log.Debug("serveConn begin") + defer s.log.Debug("serveConn done") + + ctx := context.Background() + if s.wi != nil { + ctx, nc = s.wi(ctx, nc) + } + + c := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout) + defer func() { + s.log.Debug("close client connection") + if err := c.Close(); err != nil { + s.log.WithError(err).Error("cannot close client connection") + } + }() + + header, err := c.ReadStreamedMessage(ctx, RequestHeaderMaxSize, ReqHeader) + if err != nil { + s.log.WithError(err).Error("error reading structured part") + return + } + endpoint := string(header) + + reqStructured, err := c.ReadStreamedMessage(ctx, RequestStructuredMaxSize, ReqStructured) + if err != nil { + s.log.WithError(err).Error("error reading structured part") + return + } + + s.log.WithField("endpoint", endpoint).Debug("calling 
handler") + + var res proto.Message + var sendStream zfs.StreamCopier + var handlerErr error + switch endpoint { + case EndpointSend: + var req pdu.SendReq + if err := proto.Unmarshal(reqStructured, &req); err != nil { + s.log.WithError(err).Error("cannot unmarshal send request") + return + } + res, sendStream, handlerErr = s.h.Send(ctx, &req) // SHADOWING + case EndpointRecv: + var req pdu.ReceiveReq + if err := proto.Unmarshal(reqStructured, &req); err != nil { + s.log.WithError(err).Error("cannot unmarshal receive request") + return + } + res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING + default: + s.log.WithField("endpoint", endpoint).Error("unknown endpoint") + handlerErr = fmt.Errorf("requested endpoint does not exist") + return + } + + s.log.WithField("endpoint", endpoint).WithField("errType", fmt.Sprintf("%T", handlerErr)).Debug("handler returned") + + // prepare protobuf now to return the protobuf error in the header + // if marshaling fails. 
We consider failed marshaling a handler error + var protobuf *bytes.Buffer + if handlerErr == nil { + protobufBytes, err := proto.Marshal(res) + if err != nil { + s.log.WithError(err).Error("cannot marshal handler protobuf") + handlerErr = err + } + protobuf = bytes.NewBuffer(protobufBytes) // SHADOWING + } + + var resHeaderBuf bytes.Buffer + if handlerErr == nil { + resHeaderBuf.WriteString(responseHeaderHandlerOk) + } else { + resHeaderBuf.WriteString(responseHeaderHandlerErrorPrefix) + resHeaderBuf.WriteString(handlerErr.Error()) + } + if err := c.WriteStreamedMessage(ctx, &resHeaderBuf, ResHeader); err != nil { + s.log.WithError(err).Error("cannot write response header") + return + } + + if handlerErr != nil { + s.log.Debug("early exit after handler error") + return + } + + if err := c.WriteStreamedMessage(ctx, protobuf, ResStructured); err != nil { + s.log.WithError(err).Error("cannot write structured part of response") + return + } + + if sendStream != nil { + err := c.SendStream(ctx, sendStream, ZFSStream) + if err != nil { + s.log.WithError(err).Error("cannot write send stream") + } + } + + return +} diff --git a/rpc/dataconn/dataconn_shared.go b/rpc/dataconn/dataconn_shared.go new file mode 100644 index 0000000..0ea5a34 --- /dev/null +++ b/rpc/dataconn/dataconn_shared.go @@ -0,0 +1,70 @@ +package dataconn + +import ( + "io" + "sync" + "time" + + "github.com/zrepl/zrepl/rpc/dataconn/stream" + "github.com/zrepl/zrepl/zfs" +) + +const ( + EndpointSend string = "/v1/send" + EndpointRecv string = "/v1/recv" +) + +const ( + ReqHeader uint32 = 1 + iota + ReqStructured + ResHeader + ResStructured + ZFSStream +) + +// Note that changing theses constants may break interop with other clients +// Aggressive with timing, conservative (future compatible) with buffer sizes +const ( + HeartbeatInterval = 5 * time.Second + HeartbeatPeerTimeout = 10 * time.Second + RequestHeaderMaxSize = 1 << 15 + RequestStructuredMaxSize = 1 << 22 + ResponseHeaderMaxSize = 1 << 15 + 
ResponseStructuredMaxSize = 1 << 23 +) + +// the following are protocol constants +const ( + responseHeaderHandlerOk = "HANDLER OK\n" + responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n" +) + +type streamCopier struct { + mtx sync.Mutex + used bool + streamConn *stream.Conn + closeStreamOnClose bool +} + +// WriteStreamTo implements zfs.StreamCopier +func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError { + s.mtx.Lock() + defer s.mtx.Unlock() + if s.used { + panic("streamCopier used mulitple times") + } + s.used = true + return s.streamConn.ReadStreamInto(w, ZFSStream) +} + +// Close implements zfs.StreamCopier +func (s *streamCopier) Close() error { + // only record the close here, what we do actually depends on whether + // the streamCopier is instantiated server-side or client-side + s.mtx.Lock() + defer s.mtx.Unlock() + if s.closeStreamOnClose { + return s.streamConn.Close() + } + return nil +} diff --git a/rpc/dataconn/dataconn_test.go b/rpc/dataconn/dataconn_test.go new file mode 100644 index 0000000..4d8820e --- /dev/null +++ b/rpc/dataconn/dataconn_test.go @@ -0,0 +1 @@ +package dataconn diff --git a/rpc/dataconn/frameconn/frameconn.go b/rpc/dataconn/frameconn/frameconn.go new file mode 100644 index 0000000..2736b98 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn.go @@ -0,0 +1,346 @@ +package frameconn + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/zrepl/zrepl/rpc/dataconn/base2bufpool" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" +) + +type FrameHeader struct { + Type uint32 + PayloadLen uint32 +} + +// The 4 MSBs of ft are reserved for frameconn. 
+func IsPublicFrameType(ft uint32) bool { + return (0xf<<28)&ft == 0 +} + +const ( + rstFrameType uint32 = 1<<28 + iota +) + +func assertPublicFrameType(frameType uint32) { + if !IsPublicFrameType(frameType) { + panic(fmt.Sprintf("frameconn: frame type %v cannot be used by consumers of this package", frameType)) + } +} + +func (f *FrameHeader) Unmarshal(buf []byte) { + if len(buf) != 8 { + panic(fmt.Sprintf("frame header is 8 bytes long")) + } + f.Type = binary.BigEndian.Uint32(buf[0:4]) + f.PayloadLen = binary.BigEndian.Uint32(buf[4:8]) +} + +type Conn struct { + readMtx, writeMtx sync.Mutex + nc timeoutconn.Conn + ncBuf *bufio.ReadWriter + readNextValid bool + readNext FrameHeader + nextReadErr error + bufPool *base2bufpool.Pool // no need for sync around it + shutdown shutdownFSM +} + +func Wrap(nc timeoutconn.Conn) *Conn { + return &Conn{ + nc: nc, + // ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)), + bufPool: base2bufpool.New(15, 22, base2bufpool.Allocate), // FIXME switch to Panic, but need to enforce the limits in recv for that. => need frameconn config + readNext: FrameHeader{}, + readNextValid: false, + } +} + +var ErrReadFrameLengthShort = errors.New("read frame length too short") +var ErrFixedFrameLengthMismatch = errors.New("read frame length mismatch") + +type Buffer struct { + bufpoolBuffer base2bufpool.Buffer + payloadLen uint32 +} + +func (b *Buffer) Free() { + b.bufpoolBuffer.Free() +} + +func (b *Buffer) Bytes() []byte { + return b.bufpoolBuffer.Bytes()[0:b.payloadLen] +} + +type Frame struct { + Header FrameHeader + Buffer Buffer +} + +var ErrShutdown = fmt.Errorf("frameconn: shutting down") + +// ReadFrame reads a frame from the connection. +// +// Due to an internal optimization (Readv, specifically), it is not guaranteed that a single call to +// WriteFrame unblocks a pending ReadFrame on an otherwise idle (empty) connection. 
+// The only way to guarantee that all previously written frames can reach the peer's layers on top +// of frameconn is to send an empty frame (no payload) and to ignore empty frames on the receiving side. +func (c *Conn) ReadFrame() (Frame, error) { + + if c.shutdown.IsShuttingDown() { + return Frame{}, ErrShutdown + } + + // only aquire readMtx now to prioritize the draining in Shutdown() + // over external callers (= drain public callers) + + c.readMtx.Lock() + defer c.readMtx.Unlock() + f, err := c.readFrame() + if f.Header.Type == rstFrameType { + c.shutdown.Begin() + return Frame{}, ErrShutdown + } + return f, err +} + +// callers must have readMtx locked +func (c *Conn) readFrame() (Frame, error) { + + if c.nextReadErr != nil { + ret := c.nextReadErr + c.nextReadErr = nil + return Frame{}, ret + } + + if !c.readNextValid { + var buf [8]byte + if _, err := io.ReadFull(c.nc, buf[:]); err != nil { + return Frame{}, err + } + c.readNext.Unmarshal(buf[:]) + c.readNextValid = true + } + + // read payload + next header + var nextHdrBuf [8]byte + buffer := c.bufPool.Get(uint(c.readNext.PayloadLen)) + bufferBytes := buffer.Bytes() + + if c.readNext.PayloadLen == 0 { + // This if statement implements the unlock-by-sending-empty-frame behavior + // documented in ReadFrame's public docs. + // + // It is crucial that we return this empty frame now: + // Consider the following plot with x-axis being time, + // P being a frame with payload, E one without, X either of P or E + // + // P P P P P P P E.....................X + // | | | | + // | | | F3 + // | | | + // | F2 |signficant time between frames because + // F1 the peer has nothing to say to us + // + // Assume we're at the point were F2's header is in c.readNext. + // That means F2 has not yet been returned. + // But because it is empty (no payload), we're already done reading it. + // If we omitted this if statement, the following would happen: + // Readv below would read [][]byte{[len(0)], [len(8)]). 
+ + c.readNextValid = false + frame := Frame{ + Header: c.readNext, + Buffer: Buffer{ + bufpoolBuffer: buffer, + payloadLen: c.readNext.PayloadLen, // 0 + }, + } + return frame, nil + } + + noNextHeader := false + if n, err := c.nc.ReadvFull([][]byte{bufferBytes, nextHdrBuf[:]}); err != nil { + noNextHeader = true + zeroPayloadAndPeerClosed := n == 0 && c.readNext.PayloadLen == 0 && err == io.EOF + zeroPayloadAndNextFrameHeaderThenPeerClosed := err == io.EOF && c.readNext.PayloadLen == 0 && n == int64(len(nextHdrBuf)) + nonzeroPayloadRecvdButNextHeaderMissing := n > 0 && uint32(n) == c.readNext.PayloadLen + if zeroPayloadAndPeerClosed || zeroPayloadAndNextFrameHeaderThenPeerClosed || nonzeroPayloadRecvdButNextHeaderMissing { + // This is the last frame on the conn. + // Store the error to be returned on the next invocation of ReadFrame. + c.nextReadErr = err + // NORETURN, this frame is still valid + } else { + return Frame{}, err + } + } + + frame := Frame{ + Header: c.readNext, + Buffer: Buffer{ + bufpoolBuffer: buffer, + payloadLen: c.readNext.PayloadLen, + }, + } + + if !noNextHeader { + c.readNext.Unmarshal(nextHdrBuf[:]) + c.readNextValid = true + } else { + c.readNextValid = false + } + + return frame, nil +} + +func (c *Conn) WriteFrame(payload []byte, frameType uint32) error { + assertPublicFrameType(frameType) + if c.shutdown.IsShuttingDown() { + return ErrShutdown + } + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + return c.writeFrame(payload, frameType) +} + +func (c *Conn) writeFrame(payload []byte, frameType uint32) error { + var hdrBuf [8]byte + binary.BigEndian.PutUint32(hdrBuf[0:4], frameType) + binary.BigEndian.PutUint32(hdrBuf[4:8], uint32(len(payload))) + bufs := net.Buffers([][]byte{hdrBuf[:], payload}) + if _, err := c.nc.WritevFull(bufs); err != nil { + return err + } + return nil +} + +func (c *Conn) Shutdown(deadline time.Time) error { + // TCP connection teardown is a bit wonky if we are in a situation + // where there is still data in 
flight (DIF) to our side:
+	// If we just close the connection, our kernel will send RSTs
+	// in response to the DIF, and those RSTs may reach the client's
+	// kernel faster than the client app is able to pull the
+	// last bytes from its kernel TCP receive buffer.
+	//
+	// Therefore, we send a frame with type rstFrameType to indicate
+	// that the connection is to be closed immediately, and further
+	// use CloseWrite instead of Close.
+	// As per definition of the wire interface, CloseWrite guarantees
+	// delivery of the data in our kernel TCP send buffer.
+	// Therefore, the client always receives the RST frame.
+	//
+	// Now what are we going to do after that?
+	//
+	// 1. Naive Option: We just call Close() right after CloseWrite.
+	// This yields the same race condition as explained above (DIF, first
+	// paragraph): The situation just became a little more unlikely because
+	// our rstFrameType + CloseWrite dance gave the client a full RTT worth of
+	// time to read the data from its TCP recv buffer.
+	//
+	// 2. Correct Option: Drain the read side until io.EOF
+	// We can read from the unclosed read-side of the connection until we get
+	// the io.EOF caused by the (well behaved) client closing the connection
+	// in response to it reading the rstFrameType frame we sent.
+	// However, this wastes resources on our side (we don't care about the
+	// pending data anymore), and has potential for (D)DoS through CPU-time
+	// exhaustion if the client just keeps sending data.
+	// Then again, this option has the advantage with well-behaved clients
+	// that we do not waste precious kernel-memory on the stale receive buffer
+	// on our side (which is still full of data that we do not intend to read).
+	//
+	// 2.1 DoS Mitigation: Bound the number of bytes to drain, then close
+	// At the time of writing, this technique is practiced by the Go http server
+	// implementation, and actually SHOULDed in the HTTP 1.1 RFC. It is
+	// important to disable the idle timeout of the underlying timeoutconn in
+	// that case and set an absolute deadline by which the socket must have
+	// been fully drained. Not too hard, though ;)
+	//
+	// 2.2: Client sends RST, not FIN when it receives an rstFrameType frame.
+	// We can use wire.(*net.TCPConn).SetLinger(0) to force an RST to be sent
+	// on a subsequent close (instead of a FIN + wait for FIN+ACK).
+	// TODO put this into Wire interface as an abstract method.
+	//
+	// 2.3 Only start draining after N*RTT
+	// We have an RTT approximation from Wire.CloseWrite, which by definition
+	// must not return before all to-be-sent-data has been acknowledged by the
+	// client. Give the client a fair chance to react, and only start draining
+	// after a multiple of the RTT has elapsed.
+	// We waste the recv buffer memory a little longer than necessary, iff the
+	// client reacts faster than expected. But we don't waste CPU time.
+	// If we apply 2.2, we'll also have the benefit that our kernel will have
+	// dropped the recv buffer memory as soon as it receives the client's RST.
+	//
+	// 3. TCP-only: OOB-messaging
+	// We can use TCP's 'urgent' flag in the client to acknowledge the receipt
+	// of the rstFrameType to us.
+	// We can thus wait for that signal while leaving the kernel buffer as is.
+
+
+	// TODO: For now, we just drain the connection (Option 2),
+	// but we enforce deadlines so the _time_ we drain the connection
+	// is bounded, although we do _that_ at full speed
+
+	defer prometheus.NewTimer(prom.ShutdownSeconds).ObserveDuration()
+
+	closeWire := func(step string) error {
+		// TODO SetLinger(0) or similar (we want RST frames here, not FINs)
+		if closeErr := c.nc.Close(); closeErr != nil {
+			prom.ShutdownCloseErrors.WithLabelValues("close").Inc()
+			return closeErr
+		}
+		return nil
+	}
+
+	hardclose := func(err error, step string) error {
+		prom.ShutdownHardCloses.WithLabelValues(step).Inc()
+		return closeWire(step)
+	}
+
+	c.shutdown.Begin()
+	// new calls to c.ReadFrame and c.WriteFrame will now return ErrShutdown
+	// Acquiring writeMtx and readMtx ensures that the last calls exit successfully
+
+	// disable renewing timeouts now, enforce the requested deadline instead
+	// we need to do this before acquiring locks to enforce the timeout on slow
+	// clients / if something hangs (DoS mitigation)
+	if err := c.nc.DisableTimeouts(); err != nil {
+		return hardclose(err, "disable_timeouts")
+	}
+	if err := c.nc.SetDeadline(deadline); err != nil {
+		return hardclose(err, "set_deadline")
+	}
+
+	c.writeMtx.Lock()
+	defer c.writeMtx.Unlock()
+
+	if err := c.writeFrame([]byte{}, rstFrameType); err != nil {
+		return hardclose(err, "write_frame")
+	}
+
+	if err := c.nc.CloseWrite(); err != nil {
+		return hardclose(err, "close_write")
+	}
+
+	c.readMtx.Lock()
+	defer c.readMtx.Unlock()
+
+	// TODO DoS mitigation: wait for client acknowledgement that they initiated Shutdown,
+	// then perform abortive close on our side. As explained above, probably requires
+	// OOB signaling such as TCP's urgent flag => transport-specific?
+ + // TODO DoS mitigation by reading limited number of bytes + // see discussion above why this is non-trivial + defer prometheus.NewTimer(prom.ShutdownDrainSeconds).ObserveDuration() + n, _ := io.Copy(ioutil.Discard, c.nc) + prom.ShutdownDrainBytesRead.Observe(float64(n)) + + return closeWire("close") +} diff --git a/rpc/dataconn/frameconn/frameconn_prometheus.go b/rpc/dataconn/frameconn/frameconn_prometheus.go new file mode 100644 index 0000000..d5c2c50 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn_prometheus.go @@ -0,0 +1,63 @@ +package frameconn + +import "github.com/prometheus/client_golang/prometheus" + +var prom struct { + ShutdownDrainBytesRead prometheus.Summary + ShutdownSeconds prometheus.Summary + ShutdownDrainSeconds prometheus.Summary + ShutdownHardCloses *prometheus.CounterVec + ShutdownCloseErrors *prometheus.CounterVec +} + +func init() { + prom.ShutdownDrainBytesRead = prometheus.NewSummary(prometheus.SummaryOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_drain_bytes_read", + Help: "Number of bytes read during the drain phase of connection shutdown", + }) + prom.ShutdownSeconds = prometheus.NewSummary(prometheus.SummaryOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_seconds", + Help: "Seconds it took for connection shutdown to complete", + }) + prom.ShutdownDrainSeconds = prometheus.NewSummary(prometheus.SummaryOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_drain_seconds", + Help: "Seconds it took from read-side-drain until shutdown completion", + }) + prom.ShutdownHardCloses = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_hard_closes", + Help: "Number of hard connection closes during shutdown (abortive close)", + }, []string{"step"}) + prom.ShutdownCloseErrors = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "zrepl", + Subsystem: "frameconn", + Name: "shutdown_close_errors", + Help: 
"Number of errors closing the underlying network connection. Should alert on this", + }, []string{"step"}) +} + +func PrometheusRegister(registry prometheus.Registerer) error { + if err := registry.Register(prom.ShutdownDrainBytesRead); err != nil { + return err + } + if err := registry.Register(prom.ShutdownSeconds); err != nil { + return err + } + if err := registry.Register(prom.ShutdownDrainSeconds); err != nil { + return err + } + if err := registry.Register(prom.ShutdownHardCloses); err != nil { + return err + } + if err := registry.Register(prom.ShutdownCloseErrors); err != nil { + return err + } + return nil +} diff --git a/rpc/dataconn/frameconn/frameconn_shutdown_fsm.go b/rpc/dataconn/frameconn/frameconn_shutdown_fsm.go new file mode 100644 index 0000000..980d980 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn_shutdown_fsm.go @@ -0,0 +1,37 @@ +package frameconn + +import "sync" + +type shutdownFSM struct { + mtx sync.Mutex + state shutdownFSMState +} + +type shutdownFSMState uint32 + +const ( + shutdownStateOpen shutdownFSMState = iota + shutdownStateBegin +) + +func newShutdownFSM() *shutdownFSM { + fsm := &shutdownFSM{ + state: shutdownStateOpen, + } + return fsm +} + +func (f *shutdownFSM) Begin() (thisCallStartedShutdown bool) { + f.mtx.Lock() + defer f.mtx.Unlock() + thisCallStartedShutdown = f.state != shutdownStateOpen + f.state = shutdownStateBegin + return thisCallStartedShutdown +} + +func (f *shutdownFSM) IsShuttingDown() bool { + f.mtx.Lock() + defer f.mtx.Unlock() + return f.state != shutdownStateOpen +} + diff --git a/rpc/dataconn/frameconn/frameconn_test.go b/rpc/dataconn/frameconn/frameconn_test.go new file mode 100644 index 0000000..b070c79 --- /dev/null +++ b/rpc/dataconn/frameconn/frameconn_test.go @@ -0,0 +1,22 @@ +package frameconn + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsPublicFrameType(t *testing.T) { + for i := uint32(0); i < 256; i++ { + i := i + t.Run(fmt.Sprintf("^%d", i), 
func(t *testing.T) { + assert.False(t, IsPublicFrameType(^i)) + }) + } + assert.True(t, IsPublicFrameType(0)) + assert.True(t, IsPublicFrameType(1)) + assert.True(t, IsPublicFrameType(255)) + assert.False(t, IsPublicFrameType(rstFrameType)) +} + diff --git a/rpc/dataconn/heartbeatconn/heartbeatconn.go b/rpc/dataconn/heartbeatconn/heartbeatconn.go new file mode 100644 index 0000000..2924fdc --- /dev/null +++ b/rpc/dataconn/heartbeatconn/heartbeatconn.go @@ -0,0 +1,137 @@ +package heartbeatconn + +import ( + "fmt" + "net" + "sync/atomic" + "time" + + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" +) + +type Conn struct { + state state + // if not nil, opErr is returned for ReadFrame and WriteFrame (not for Close, though) + opErr atomic.Value // error + fc *frameconn.Conn + sendInterval, timeout time.Duration + stopSend chan struct{} + lastFrameSent atomic.Value // time.Time +} + +type HeartbeatTimeout struct{} + +func (e HeartbeatTimeout) Error() string { + return "heartbeat timeout" +} + +func (e HeartbeatTimeout) Temporary() bool { return true } + +func (e HeartbeatTimeout) Timeout() bool { return true } + +var _ net.Error = HeartbeatTimeout{} + +type state = int32 + +const ( + stateInitial state = 0 + stateClosed state = 2 +) + +const ( + heartbeat uint32 = 1 << 24 +) + +// The 4 MSBs of ft are reserved for frameconn, we reserve the next 4 MSB for us. 
+func IsPublicFrameType(ft uint32) bool { + return frameconn.IsPublicFrameType(ft) && (0xf<<24)&ft == 0 +} + +func assertPublicFrameType(frameType uint32) { + if !IsPublicFrameType(frameType) { + panic(fmt.Sprintf("heartbeatconn: frame type %v cannot be used by consumers of this package", frameType)) + } +} + +func Wrap(nc timeoutconn.Wire, sendInterval, timeout time.Duration) *Conn { + c := &Conn{ + fc: frameconn.Wrap(timeoutconn.Wrap(nc, timeout)), + stopSend: make(chan struct{}), + sendInterval: sendInterval, + timeout: timeout, + } + c.lastFrameSent.Store(time.Now()) + go c.sendHeartbeats() + return c +} + +func (c *Conn) Shutdown() error { + normalClose := atomic.CompareAndSwapInt32(&c.state, stateInitial, stateClosed) + if normalClose { + close(c.stopSend) + } + return c.fc.Shutdown(time.Now().Add(c.timeout)) +} + +// started as a goroutine in constructor +func (c *Conn) sendHeartbeats() { + sleepTime := func(now time.Time) time.Duration { + lastSend := c.lastFrameSent.Load().(time.Time) + return lastSend.Add(c.sendInterval).Sub(now) + } + timer := time.NewTimer(sleepTime(time.Now())) + defer timer.Stop() + for { + select { + case <-c.stopSend: + return + case now := <-timer.C: + func() { + defer func() { + timer.Reset(sleepTime(time.Now())) + }() + if sleepTime(now) > 0 { + return + } + debug("send heartbeat") + // if the connection is in zombie mode (aka iptables DROP inbetween peers) + // this call or one of its successors will block after filling up the kernel tx buffer + c.fc.WriteFrame([]byte{}, heartbeat) + // ignore errors from WriteFrame to rate-limit SendHeartbeat retries + c.lastFrameSent.Store(time.Now()) + }() + } + } +} + +func (c *Conn) ReadFrame() (frameconn.Frame, error) { + return c.readFrameFiltered() +} + +func (c *Conn) readFrameFiltered() (frameconn.Frame, error) { + for { + f, err := c.fc.ReadFrame() + if err != nil { + return frameconn.Frame{}, err + } + if IsPublicFrameType(f.Header.Type) { + return f, nil + } + if f.Header.Type != 
heartbeat { + return frameconn.Frame{}, fmt.Errorf("unknown frame type %x", f.Header.Type) + } + // drop heartbeat frame + debug("received heartbeat") + continue + } +} + +func (c *Conn) WriteFrame(payload []byte, frameType uint32) error { + assertPublicFrameType(frameType) + err := c.fc.WriteFrame(payload, frameType) + if err == nil { + c.lastFrameSent.Store(time.Now()) + } + return err +} diff --git a/rpc/dataconn/heartbeatconn/heartbeatconn_debug.go b/rpc/dataconn/heartbeatconn/heartbeatconn_debug.go new file mode 100644 index 0000000..6bdea8d --- /dev/null +++ b/rpc/dataconn/heartbeatconn/heartbeatconn_debug.go @@ -0,0 +1,20 @@ +package heartbeatconn + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DATACONN_HEARTBEATCONN_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc/dataconn/heartbeatconn: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/dataconn/heartbeatconn/heartbeatconn_test.go b/rpc/dataconn/heartbeatconn/heartbeatconn_test.go new file mode 100644 index 0000000..8b02459 --- /dev/null +++ b/rpc/dataconn/heartbeatconn/heartbeatconn_test.go @@ -0,0 +1,26 @@ +package heartbeatconn + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" +) + +func TestFrameTypes(t *testing.T) { + assert.True(t, frameconn.IsPublicFrameType(heartbeat)) +} + +func TestNegativeTimer(t *testing.T) { + + timer := time.NewTimer(-1 * time.Second) + defer timer.Stop() + time.Sleep(100 * time.Millisecond) + select { + case <-timer.C: + t.Log("timer with negative time fired, that's what we want") + default: + t.Fail() + } +} diff --git a/rpc/dataconn/microbenchmark/microbenchmark.go b/rpc/dataconn/microbenchmark/microbenchmark.go new file mode 100644 index 0000000..5ee58cf --- /dev/null +++ b/rpc/dataconn/microbenchmark/microbenchmark.go @@ -0,0 +1,135 @@ 
+package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "io" + "net" + "os" + + "github.com/pkg/profile" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/replication/pdu" +) + +func orDie(err error) { + if err != nil { + panic(err) + } +} + +type devNullHandler struct{} + +func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { + var res pdu.SendRes + return &res, os.Stdin, nil +} + +func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream io.Reader) (*pdu.ReceiveRes, error) { + var buf [1<<15]byte + _, err := io.CopyBuffer(os.Stdout, stream, buf[:]) + var res pdu.ReceiveRes + return &res, err +} + +type tcpConnecter struct { + net, addr string +} + +func (c tcpConnecter) Connect(ctx context.Context) (net.Conn, error) { + return net.Dial(c.net, c.addr) +} + +var args struct { + addr string + appmode string + direction string + profile bool +} + +func server() { + + log := logger.NewStderrDebugLogger() + log.Debug("starting server") + l, err := net.Listen("tcp", args.addr) + orDie(err) + + srvConfig := dataconn.ServerConfig{ + Shared: dataconn.SharedConfig { + MaxProtoLen: 4096, + MaxHeaderLen: 4096, + SendChunkSize: 1 << 17, + MaxRecvChunkSize: 1 << 17, + }, + } + srv := dataconn.NewServer(devNullHandler{}, srvConfig, nil) + + ctx := context.Background() + ctx = dataconn.WithLogger(ctx, log) + srv.Serve(ctx, l) + +} + +func main() { + + flag.BoolVar(&args.profile, "profile", false, "") + flag.StringVar(&args.addr, "address", ":8888", "") + flag.StringVar(&args.appmode, "appmode", "client|server", "") + flag.StringVar(&args.direction, "direction", "", "send|recv") + flag.Parse() + + if args.profile { + defer profile.Start(profile.CPUProfile).Stop() + } + + switch args.appmode { + case "client": + client() + case "server": + server() + default: + orDie(fmt.Errorf("unknown appmode %q", args.appmode)) + } +} + +func client() { + + logger := 
logger.NewStderrDebugLogger() + ctx := context.Background() + ctx = dataconn.WithLogger(ctx, logger) + + clientConfig := dataconn.ClientConfig{ + Shared: dataconn.SharedConfig { + MaxProtoLen: 4096, + MaxHeaderLen: 4096, + SendChunkSize: 1 << 17, + MaxRecvChunkSize: 1 << 17, + }, + } + orDie(clientConfig.Validate()) + + connecter := tcpConnecter{"tcp", args.addr} + client := dataconn.NewClient(connecter, clientConfig) + + switch args.direction { + case "send": + req := pdu.SendReq{} + _, stream, err := client.ReqSendStream(ctx, &req) + orDie(err) + var buf [1<<15]byte + _, err = io.CopyBuffer(os.Stdout, stream, buf[:]) + orDie(err) + case "recv": + var buf bytes.Buffer + buf.WriteString("teststreamtobereceived") + req := pdu.ReceiveReq{} + _, err := client.ReqRecv(ctx, &req, os.Stdin) + orDie(err) + default: + orDie(fmt.Errorf("unknown direction%q", args.direction)) + } + +} diff --git a/rpc/dataconn/stream/stream.go b/rpc/dataconn/stream/stream.go new file mode 100644 index 0000000..10e7ff0 --- /dev/null +++ b/rpc/dataconn/stream/stream.go @@ -0,0 +1,269 @@ +package stream + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "strings" + "unicode/utf8" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/base2bufpool" + "github.com/zrepl/zrepl/rpc/dataconn/frameconn" + "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" + "github.com/zrepl/zrepl/zfs" +) + +type Logger = logger.Logger + +type contextKey int + +const ( + contextKeyLogger contextKey = 1 + iota +) + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLogger, log) +} + +func getLog(ctx context.Context) Logger { + log, ok := ctx.Value(contextKeyLogger).(Logger) + if !ok { + log = logger.NewNullLogger() + } + return log +} + +// Frame types used by this package. +// 4 MSBs are reserved for frameconn, next 4 MSB for heartbeatconn, next 4 MSB for us. 
+const ( + StreamErrTrailer uint32 = 1 << (16 + iota) + End + // max 16 +) + +// NOTE: make sure to add a tests for each frame type that checks +// whether it is heartbeatconn.IsPublicFrameType() + +// Check whether the given frame type is allowed to be used by +// consumers of this package. Intended for use in unit tests. +func IsPublicFrameType(ft uint32) bool { + return frameconn.IsPublicFrameType(ft) && heartbeatconn.IsPublicFrameType(ft) && ((0xf<<16)&ft == 0) +} + +const FramePayloadShift = 19 + +var bufpool = base2bufpool.New(FramePayloadShift, FramePayloadShift, base2bufpool.Panic) + +// if sendStream returns an error, that error will be sent as a trailer to the client +// ok will return nil, though. +func writeStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) { + debug("writeStream: enter stype=%v", stype) + defer debug("writeStream: return") + if stype == 0 { + panic("stype must be non-zero") + } + if !IsPublicFrameType(stype) { + panic(fmt.Sprintf("stype %v is not public", stype)) + } + return doWriteStream(ctx, c, stream, stype) +} + +func doWriteStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) { + + // RULE1 (buf == ) XOR (err == nil) + type read struct { + buf base2bufpool.Buffer + err error + } + + reads := make(chan read, 5) + go func() { + for { + buffer := bufpool.Get(1 << FramePayloadShift) + bufferBytes := buffer.Bytes() + n, err := io.ReadFull(stream, bufferBytes) + buffer.Shrink(uint(n)) + // if we received anything, send one read without an error (RULE 1) + if n > 0 { + reads <- read{buffer, nil} + } + if err == io.ErrUnexpectedEOF { + // happens iff io.ReadFull read io.EOF from stream + err = io.EOF + } + if err != nil { + reads <- read{err: err} // RULE1 + close(reads) + return + } + } + }() + + for read := range reads { + if read.err == nil { + // RULE 1: read.buf is valid + // next line is the hot path... 
+ writeErr := c.WriteFrame(read.buf.Bytes(), stype) + read.buf.Free() + if writeErr != nil { + return nil, writeErr + } + continue + } else if read.err == io.EOF { + if err := c.WriteFrame([]byte{}, End); err != nil { + return nil, err + } + break + } else { + errReader := strings.NewReader(read.err.Error()) + errReadErrReader, errConnWrite := doWriteStream(ctx, c, errReader, StreamErrTrailer) + if errReadErrReader != nil { + panic(errReadErrReader) // in-memory, cannot happen + } + return read.err, errConnWrite + } + } + + return nil, nil +} + +type ReadStreamErrorKind int + +const ( + ReadStreamErrorKindConn ReadStreamErrorKind = 1 + iota + ReadStreamErrorKindWrite + ReadStreamErrorKindSource + ReadStreamErrorKindStreamErrTrailerEncoding + ReadStreamErrorKindUnexpectedFrameType +) + +type ReadStreamError struct { + Kind ReadStreamErrorKind + Err error +} + +func (e *ReadStreamError) Error() string { + kindStr := "" + switch e.Kind { + case ReadStreamErrorKindConn: + kindStr = " read error: " + case ReadStreamErrorKindWrite: + kindStr = " write error: " + case ReadStreamErrorKindSource: + kindStr = " source error: " + case ReadStreamErrorKindStreamErrTrailerEncoding: + kindStr = " source implementation error: " + case ReadStreamErrorKindUnexpectedFrameType: + kindStr = " protocol error: " + } + return fmt.Sprintf("stream:%s%s", kindStr, e.Err) +} + +var _ net.Error = &ReadStreamError{} + +func (e ReadStreamError) netErr() net.Error { + if netErr, ok := e.Err.(net.Error); ok { + return netErr + } + return nil +} + +func (e ReadStreamError) Timeout() bool { + if netErr := e.netErr(); netErr != nil { + return netErr.Timeout() + } + return false +} + +func (e ReadStreamError) Temporary() bool { + if netErr := e.netErr(); netErr != nil { + return netErr.Temporary() + } + return false +} + +var _ zfs.StreamCopierError = &ReadStreamError{} + +func (e ReadStreamError) IsReadError() bool { + return e.Kind != ReadStreamErrorKindWrite +} + +func (e ReadStreamError) 
IsWriteError() bool { + return e.Kind == ReadStreamErrorKindWrite +} + +type readFrameResult struct { + f frameconn.Frame + err error +} + +func readFrames(reads chan<- readFrameResult, c *heartbeatconn.Conn) { + for { + var r readFrameResult + r.f, r.err = c.ReadFrame() + reads <- r + if r.err != nil { + return + } + } +} + +// ReadStream will close c if an error reading from c or writing to receiver occurs +// +// readStream calls itself recursively to read multi-frame error trailers +// Thus, the reads channel needs to be a parameter. +func readStream(reads <-chan readFrameResult, c *heartbeatconn.Conn, receiver io.Writer, stype uint32) *ReadStreamError { + + var f frameconn.Frame + for read := range reads { + debug("readStream: read frame %v %v", read.f.Header, read.err) + f = read.f + if read.err != nil { + return &ReadStreamError{ReadStreamErrorKindConn, read.err} + } + if f.Header.Type != stype { + break + } + + n, err := receiver.Write(f.Buffer.Bytes()) + if err != nil { + f.Buffer.Free() + return &ReadStreamError{ReadStreamErrorKindWrite, err} // FIXME wrap as writer error + } + if n != len(f.Buffer.Bytes()) { + f.Buffer.Free() + return &ReadStreamError{ReadStreamErrorKindWrite, io.ErrShortWrite} + } + f.Buffer.Free() + } + + if f.Header.Type == End { + debug("readStream: End reached") + return nil + } + + if f.Header.Type == StreamErrTrailer { + debug("readStream: begin of StreamErrTrailer") + var errBuf bytes.Buffer + if n, err := errBuf.Write(f.Buffer.Bytes()); n != len(f.Buffer.Bytes()) || err != nil { + panic(fmt.Sprintf("unexpected bytes.Buffer write error: %v %v", n, err)) + } + // recursion ftw! 
we won't enter this if stmt because stype == StreamErrTrailer in the following call + rserr := readStream(reads, c, &errBuf, StreamErrTrailer) + if rserr != nil && rserr.Kind == ReadStreamErrorKindWrite { + panic(fmt.Sprintf("unexpected bytes.Buffer write error: %s", rserr)) + } else if rserr != nil { + debug("readStream: rserr != nil && != ReadStreamErrorKindWrite: %v %v\n", rserr.Kind, rserr.Err) + return rserr + } + if !utf8.Valid(errBuf.Bytes()) { + return &ReadStreamError{ReadStreamErrorKindStreamErrTrailerEncoding, fmt.Errorf("source error, but not encoded as UTF-8")} + } + return &ReadStreamError{ReadStreamErrorKindSource, fmt.Errorf("%s", errBuf.String())} + } + + return &ReadStreamError{ReadStreamErrorKindUnexpectedFrameType, fmt.Errorf("unexpected frame type %v (expected %v)", f.Header.Type, stype)} +} diff --git a/rpc/dataconn/stream/stream_conn.go b/rpc/dataconn/stream/stream_conn.go new file mode 100644 index 0000000..ce0c052 --- /dev/null +++ b/rpc/dataconn/stream/stream_conn.go @@ -0,0 +1,194 @@ +package stream + +import ( + "bytes" + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" + "github.com/zrepl/zrepl/zfs" +) + +type Conn struct { + hc *heartbeatconn.Conn + + // whether the per-conn readFrames goroutine completed + waitReadFramesDone chan struct{} + // filled by per-conn readFrames goroutine + frameReads chan readFrameResult + + // readMtx serializes read stream operations because we inherently only + // support a single stream at a time over hc. + readMtx sync.Mutex + readClean bool + allowWriteStreamTo bool + + // writeMtx serializes write stream operations because we inherently only + // support a single stream at a time over hc. 
+ writeMtx sync.Mutex + writeClean bool +} + +var readMessageSentinel = fmt.Errorf("read stream complete") + +type writeStreamToErrorUnknownState struct{} + +func (e writeStreamToErrorUnknownState) Error() string { + return "dataconn read stream: connection is in unknown state" +} + +func (e writeStreamToErrorUnknownState) IsReadError() bool { return true } + +func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false } + +func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn { + hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout) + conn := &Conn{ + hc: hc, readClean: true, writeClean: true, + waitReadFramesDone: make(chan struct{}), + frameReads: make(chan readFrameResult, 5), // FIXME constant + } + go conn.readFrames() + return conn +} + +func isConnCleanAfterRead(res *ReadStreamError) bool { + return res == nil || res.Kind == ReadStreamErrorKindSource || res.Kind == ReadStreamErrorKindStreamErrTrailerEncoding +} + +func isConnCleanAfterWrite(err error) bool { + return err == nil +} + +var ErrReadFramesStopped = fmt.Errorf("stream: reading frames stopped") + +func (c *Conn) readFrames() { + defer close(c.waitReadFramesDone) + defer close(c.frameReads) + readFrames(c.frameReads, c.hc) +} + +func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) ([]byte, *ReadStreamError) { + c.readMtx.Lock() + defer c.readMtx.Unlock() + if !c.readClean { + return nil, &ReadStreamError{ + Kind: ReadStreamErrorKindConn, + Err: fmt.Errorf("dataconn read message: connection is in unknown state"), + } + } + + r, w := io.Pipe() + var buf bytes.Buffer + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + lr := io.LimitReader(r, int64(maxSize)) + if _, err := io.Copy(&buf, lr); err != nil && err != readMessageSentinel { + panic(err) + } + }() + err := readStream(c.frameReads, c.hc, w, frameType) + c.readClean = isConnCleanAfterRead(err) + w.CloseWithError(readMessageSentinel) 
+ wg.Wait() + if err != nil { + return nil, err + } else { + return buf.Bytes(), nil + } +} + +// WriteStreamTo reads a stream from Conn and writes it to w. +func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) zfs.StreamCopierError { + c.readMtx.Lock() + defer c.readMtx.Unlock() + if !c.readClean { + return writeStreamToErrorUnknownState{} + } + var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType) + c.readClean = isConnCleanAfterRead(err) + + // https://golang.org/doc/faq#nil_error + if err == nil { + return nil + } + return err +} + +func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) error { + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + if !c.writeClean { + return fmt.Errorf("dataconn write message: connection is in unknown state") + } + errBuf, errConn := writeStream(ctx, c.hc, buf, frameType) + if errBuf != nil { + panic(errBuf) + } + c.writeClean = isConnCleanAfterWrite(errConn) + return errConn +} + +func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) error { + c.writeMtx.Lock() + defer c.writeMtx.Unlock() + if !c.writeClean { + return fmt.Errorf("dataconn send stream: connection is in unknown state") + } + + // avoid io.Pipe if zfs.StreamCopier is an io.Reader + var r io.Reader + var w *io.PipeWriter + streamCopierErrChan := make(chan zfs.StreamCopierError, 1) + if reader, ok := src.(io.Reader); ok { + r = reader + streamCopierErrChan <- nil + close(streamCopierErrChan) + } else { + r, w = io.Pipe() + go func() { + streamCopierErrChan <- src.WriteStreamTo(w) + w.Close() + }() + } + + type writeStreamRes struct { + errStream, errConn error + } + writeStreamErrChan := make(chan writeStreamRes, 1) + go func() { + var res writeStreamRes + res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType) + if w != nil { + w.CloseWithError(res.errStream) + } + writeStreamErrChan <- res + }() + + writeRes := <-writeStreamErrChan + streamCopierErr := 
<-streamCopierErrChan + c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct? + if streamCopierErr != nil && streamCopierErr.IsReadError() { + return streamCopierErr // something on our side is bad + } else { + if writeRes.errStream != nil { + return writeRes.errStream + } else if writeRes.errConn != nil { + return writeRes.errConn + } + // TODO combined error? + return streamCopierErr + } +} + +func (c *Conn) Close() error { + err := c.hc.Shutdown() + <-c.waitReadFramesDone + return err +} diff --git a/rpc/dataconn/stream/stream_debug.go b/rpc/dataconn/stream/stream_debug.go new file mode 100644 index 0000000..c1e2ef9 --- /dev/null +++ b/rpc/dataconn/stream/stream_debug.go @@ -0,0 +1,20 @@ +package stream + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DATACONN_STREAM_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc/dataconn/stream: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/dataconn/stream/stream_test.go b/rpc/dataconn/stream/stream_test.go new file mode 100644 index 0000000..7da4cfa --- /dev/null +++ b/rpc/dataconn/stream/stream_test.go @@ -0,0 +1,131 @@ +package stream + +import ( + "bytes" + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" + "github.com/zrepl/zrepl/util/socketpair" +) + +func TestFrameTypesOk(t *testing.T) { + t.Logf("%v", End) + assert.True(t, heartbeatconn.IsPublicFrameType(End)) + assert.True(t, heartbeatconn.IsPublicFrameType(StreamErrTrailer)) +} + +func TestStreamer(t *testing.T) { + + anc, bnc, err := socketpair.SocketPair() + require.NoError(t, err) + + hto := 1 * time.Hour + a := heartbeatconn.Wrap(anc, hto, hto) + b := heartbeatconn.Wrap(bnc, hto, hto) + + log := 
logger.NewStderrDebugLogger() + ctx := WithLogger(context.Background(), log) + + stype := uint32(0x23) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + var buf bytes.Buffer + buf.Write( + bytes.Repeat([]byte{1, 2}, 1<<25), + ) + writeStream(ctx, a, &buf, stype) + log.Debug("WriteStream returned") + a.Shutdown() + }() + + go func() { + defer wg.Done() + var buf bytes.Buffer + ch := make(chan readFrameResult, 5) + wg.Add(1) + go func() { + defer wg.Done() + readFrames(ch, b) + }() + err := readStream(ch, b, &buf, stype) + log.WithField("errType", fmt.Sprintf("%T %v", err, err)).Debug("ReadStream returned") + assert.Nil(t, err) + expected := bytes.Repeat([]byte{1, 2}, 1<<25) + assert.True(t, bytes.Equal(expected, buf.Bytes())) + b.Shutdown() + }() + + wg.Wait() + +} + +type errReader struct { + t *testing.T + readErr error +} + +func (er errReader) Read(p []byte) (n int, err error) { + er.t.Logf("errReader.Read called") + return 0, er.readErr +} + +func TestMultiFrameStreamErrTraileror(t *testing.T) { + anc, bnc, err := socketpair.SocketPair() + require.NoError(t, err) + + hto := 1 * time.Hour + a := heartbeatconn.Wrap(anc, hto, hto) + b := heartbeatconn.Wrap(bnc, hto, hto) + + log := logger.NewStderrDebugLogger() + ctx := WithLogger(context.Background(), log) + + longErr := fmt.Errorf("an error that definitley spans more than one frame:\n%s", strings.Repeat("a\n", 1<<4)) + + stype := uint32(0x23) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + r := errReader{t, longErr} + writeStream(ctx, a, &r, stype) + a.Shutdown() + }() + + go func() { + defer wg.Done() + defer b.Shutdown() + var buf bytes.Buffer + ch := make(chan readFrameResult, 5) + wg.Add(1) + go func() { + defer wg.Done() + readFrames(ch, b) + }() + err := readStream(ch, b, &buf, stype) + t.Logf("%s", err) + require.NotNil(t, err) + assert.True(t, buf.Len() == 0) + assert.Equal(t, err.Kind, ReadStreamErrorKindSource) + receivedErr := err.Err.Error() + expectedErr 
:= longErr.Error() + assert.True(t, receivedErr == expectedErr) // builtin Equals is too slow + if receivedErr != expectedErr { + t.Logf("lengths: %v %v", len(receivedErr), len(expectedErr)) + } + }() + + wg.Wait() +} diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore new file mode 100644 index 0000000..650f629 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore @@ -0,0 +1,12 @@ +# setup-specific +inventory +*.retry + +# generated by gen_files.sh +files/*ssh_client_identity +files/*ssh_client_identity.pub +files/*.tls.*.key +files/*.tls.*.csr +files/*.tls.*.crt +files/wireevaluator + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md new file mode 100644 index 0000000..958c539 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/README.md @@ -0,0 +1,15 @@ +This directory contains very hacky test automation for wireevaluator based on nested Ansible playbooks. + +* Copy `inventory.example` to `inventory` +* Adjust `inventory` IP addresses as needed +* Make sure there's an OpenSSH server running on the serve host +* Make sure there's no firewalling whatsoever between the hosts +* Run `GENKEYS=1 ./gen_files.sh` to re-generate self-signed TLS certs +* Run the following command, adjusting the `wireevaluator_repeat` value to the number of times you want to repeat each test + +``` +ansible-playbook -i inventory all.yml -e `wireevaluator_repeat=3` +``` + +Generally, things are fine if the playbook doesn't show any panics from wireevaluator. 
+ diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml new file mode 100644 index 0000000..9abe411 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/all.yml @@ -0,0 +1,17 @@ +- hosts: connect,serve + tasks: + + - name: "run test" + include: internal_prepare_and_run_repeated.yml + wireevaluator_transport: "{{config.0}}" + wireevaluator_case: "{{config.1}}" + wireevaluator_repeat: "{{wireevaluator_repeat}}" + with_cartesian: + - [ tls, ssh, tcp ] + - + - closewrite_server + - closewrite_client + - readdeadline_server + - readdeadline_client + loop_control: + loop_var: config diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh new file mode 100755 index 0000000..e7a72e9 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +set -e + +cd "$( dirname "${BASH_SOURCE[0]}")" + +FILESDIR="$(pwd)"/files + +echo "[INFO] compile binary" +pushd .. >/dev/null +go build -o $FILESDIR/wireevaluator +popd >/dev/null + +if [ "$GENKEYS" == "" ]; then + echo "[INFO] GENKEYS environment variable not set, assumed to be valid" + exit 0 +fi + +echo "[INFO] gen ssh key" +ssh-keygen -f "$FILESDIR/wireevaluator.ssh_client_identity" -t ed25519 + +echo "[INFO] gen tls keys" + +cakey="$FILESDIR/wireevaluator.tls.ca.key" +cacrt="$FILESDIR/wireevaluator.tls.ca.crt" +hostprefix="$FILESDIR/wireevaluator.tls" + +openssl genrsa -out "$cakey" 4096 +openssl req -x509 -new -nodes -key "$cakey" -sha256 -days 1 -out "$cacrt" + +declare -a HOSTS +HOSTS+=("theserver") +HOSTS+=("theclient") + +for host in "${HOSTS[@]}"; do + key="${hostprefix}.${host}.key" + csr="${hostprefix}.${host}.csr" + crt="${hostprefix}.${host}.crt" + openssl genrsa -out "$key" 2048 + + ( + echo "." + echo "." + echo "." + echo "." + echo "." 
+ echo $host + echo "." + echo "." + echo "." + echo "." + ) | openssl req -new -key "$key" -out "$csr" + + openssl x509 -req -in "$csr" -CA "$cacrt" -CAkey "$cakey" -CAcreateserial -out "$crt" -days 1 -sha256 + +done diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml new file mode 100644 index 0000000..71d9074 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_prepare_and_run_repeated.yml @@ -0,0 +1,54 @@ +--- + +- name: compile binary and any key files required + local_action: command ./gen_files.sh + +- name: Kill test binary + shell: "killall -9 wireevaluator || true" +- name: Deploy new binary + copy: + src: "files/wireevaluator" + dest: "/opt/wireevaluator" + mode: 0755 + +- set_fact: + wireevaluator_connect_ip: "{{hostvars['connect'].ansible_host}}" + wireevaluator_serve_ip: "{{hostvars['serve'].ansible_host}}" + +- name: Deploy config + template: + src: "templates/{{wireevaluator_transport}}.yml.j2" + dest: "/opt/wireevaluator.yml" + +- name: Deploy client identity + copy: + src: "files/wireevaluator.{{item}}" + dest: "/opt/wireevaluator.{{item}}" + mode: 0400 + with_items: + - ssh_client_identity + - ssh_client_identity.pub + - tls.ca.key + - tls.ca.crt + - tls.theserver.key + - tls.theserver.crt + - tls.theclient.key + - tls.theclient.crt + +- name: Setup server ssh client identity access + when: inventory_hostname == "serve" + block: + - authorized_key: + user: root + state: present + key: "{{ lookup('file', 'files/wireevaluator.ssh_client_identity.pub') }}" + key_options: 'command="/opt/wireevaluator -mode stdinserver -config /opt/wireevaluator.yml client1"' + - file: + state: directory + mode: 0700 + path: /tmp/wireevaluator_stdinserver + +- name: repeated test + include: internal_run_test_prepared_single.yml + with_sequence: start=1 end={{wireevaluator_repeat}} + diff 
--git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml new file mode 100644 index 0000000..4c535d9 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/internal_run_test_prepared_single.yml @@ -0,0 +1,38 @@ +--- + +- debug: + msg: "run test transport={{wireevaluator_transport}} case={{wireevaluator_case}} repeatedly" + +- name: Run Server + when: inventory_hostname == "serve" + command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode serve -testcase {{wireevaluator_case}} + register: spawn_servers + async: 60 + poll: 0 + +- name: Run Client + when: inventory_hostname == "connect" + command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode connect -testcase {{wireevaluator_case}} + register: spawn_clients + async: 60 + poll: 0 + +- name: Wait for server shutdown + when: inventory_hostname == "serve" + async_status: + jid: "{{ spawn_servers.ansible_job_id}}" + delay: 0.5 + retries: 10 + +- name: Wait for client shutdown + when: inventory_hostname == "connect" + async_status: + jid: "{{ spawn_clients.ansible_job_id}}" + delay: 0.5 + retries: 10 + +- name: Wait for connections to die (TIME_WAIT conns) + command: sleep 4 + changed_when: false + + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example new file mode 100644 index 0000000..70cf1c7 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/inventory.example @@ -0,0 +1,2 @@ +connect ansible_user=root ansible_host=192.168.122.128 wireevaluator_mode="connect" +serve ansible_user=root ansible_host=192.168.122.129 wireevaluator_mode="serve" diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 new 
file mode 100644 index 0000000..559f0d0 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/ssh.yml.j2 @@ -0,0 +1,13 @@ +connect: + type: ssh+stdinserver + host: {{wireevaluator_serve_ip}} + user: root + port: 22 + identity_file: /opt/wireevaluator.ssh_client_identity + options: # optional, default [], `-o` arguments passed to ssh + - "Compression=yes" +serve: + type: stdinserver + client_identities: + - "client1" + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 new file mode 100644 index 0000000..3aaa29e --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tcp.yml.j2 @@ -0,0 +1,10 @@ +connect: + type: tcp + address: "{{wireevaluator_serve_ip}}:8888" +serve: + type: tcp + listen: ":8888" + clients: { + "{{wireevaluator_connect_ip}}" : "client1" + } + diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 new file mode 100644 index 0000000..6fdaa94 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/templates/tls.yml.j2 @@ -0,0 +1,16 @@ +connect: + type: tls + address: "{{wireevaluator_serve_ip}}:8888" + ca: "/opt/wireevaluator.tls.ca.crt" + cert: "/opt/wireevaluator.tls.theclient.crt" + key: "/opt/wireevaluator.tls.theclient.key" + server_cn: "theserver" + +serve: + type: tls + listen: ":8888" + ca: "/opt/wireevaluator.tls.ca.crt" + cert: "/opt/wireevaluator.tls.theserver.crt" + key: "/opt/wireevaluator.tls.theserver.key" + client_cns: + - "theclient" diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go new file mode 100644 index 0000000..65ea9cd --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go @@ -0,0 +1,111 @@ +// a 
tool to test whether a given transport implements the timeoutconn.Wire interface +package main + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "os" + "path" + + netssh "github.com/problame/go-netssh" + "github.com/zrepl/yaml-config" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" + transportconfig "github.com/zrepl/zrepl/transport/fromconfig" +) + +func noerror(err error) { + if err != nil { + panic(err) + } +} + +type Config struct { + Connect config.ConnectEnum + Serve config.ServeEnum +} + +var args struct { + mode string + configPath string + testCase string +} + +var conf Config + +type TestCase interface { + Client(wire transport.Wire) + Server(wire transport.Wire) +} + +func main() { + flag.StringVar(&args.mode, "mode", "", "connect|serve") + flag.StringVar(&args.configPath, "config", "", "config file path") + flag.StringVar(&args.testCase, "testcase", "", "") + flag.Parse() + + bytes, err := ioutil.ReadFile(args.configPath) + noerror(err) + err = yaml.UnmarshalStrict(bytes, &conf) + noerror(err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + global := &config.Global{ + Serve: &config.GlobalServe{ + StdinServer: &config.GlobalStdinServer{ + SockDir: "/tmp/wireevaluator_stdinserver", + }, + }, + } + + switch args.mode { + case "connect": + tc, err := getTestCase(args.testCase) + noerror(err) + connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect) + noerror(err) + wire, err := connecter.Connect(ctx) + noerror(err) + tc.Client(wire) + case "serve": + tc, err := getTestCase(args.testCase) + noerror(err) + lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve) + noerror(err) + l, err := lf() + noerror(err) + conn, err := l.Accept(ctx) + noerror(err) + tc.Server(conn) + case "stdinserver": + identity := flag.Arg(0) + unixaddr := path.Join(global.Serve.StdinServer.SockDir, identity) + err := netssh.Proxy(ctx, unixaddr) + if err == nil { + os.Exit(0) + } + 
panic(err) + default: + panic(fmt.Sprintf("unknown mode %q", args.mode)) + } + +} + +func getTestCase(tcName string) (TestCase, error) { + switch tcName { + case "closewrite_server": + return &CloseWrite{mode: CloseWriteServerSide}, nil + case "closewrite_client": + return &CloseWrite{mode: CloseWriteClientSide}, nil + case "readdeadline_client": + return &Deadlines{mode: DeadlineModeClientTimeout}, nil + case "readdeadline_server": + return &Deadlines{mode: DeadlineModeServerTimeout}, nil + default: + return nil, fmt.Errorf("unknown test case %q", tcName) + } +} diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go new file mode 100644 index 0000000..cc1907e --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go @@ -0,0 +1,110 @@ +package main + +import ( + "bytes" + "io" + "io/ioutil" + "log" + + "github.com/zrepl/zrepl/transport" +) + +type CloseWriteMode uint + +const ( + CloseWriteClientSide CloseWriteMode = 1 + iota + CloseWriteServerSide +) + +type CloseWrite struct { + mode CloseWriteMode +} + +// sent repeatedly +var closeWriteTestSendData = bytes.Repeat([]byte{0x23, 0x42}, 1<<24) +var closeWriteErrorMsg = []byte{0xb, 0xa, 0xd, 0xf, 0x0, 0x0, 0xd} + +func (m CloseWrite) Client(wire transport.Wire) { + switch m.mode { + case CloseWriteClientSide: + m.receiver(wire) + case CloseWriteServerSide: + m.sender(wire) + default: + panic(m.mode) + } +} + +func (m CloseWrite) Server(wire transport.Wire) { + switch m.mode { + case CloseWriteClientSide: + m.sender(wire) + case CloseWriteServerSide: + m.receiver(wire) + default: + panic(m.mode) + } +} + +func (CloseWrite) sender(wire transport.Wire) { + defer func() { + closeErr := wire.Close() + log.Printf("closeErr=%T %s", closeErr, closeErr) + }() + + type opResult struct { + err error + } + writeDone := make(chan struct{}, 1) + go func() { + close(writeDone) + for 
{ + _, err := wire.Write(closeWriteTestSendData) + if err != nil { + return + } + } + }() + + defer func() { + <-writeDone + }() + + var respBuf bytes.Buffer + _, err := io.Copy(&respBuf, wire) + if err != nil { + log.Fatalf("should have received io.EOF, which is masked by io.Copy, got: %s", err) + } + if !bytes.Equal(respBuf.Bytes(), closeWriteErrorMsg) { + log.Fatalf("did not receive error message, got response with len %v:\n%v", respBuf.Len(), respBuf.Bytes()) + } + +} + +func (CloseWrite) receiver(wire transport.Wire) { + + // consume half the test data, then detect an error, send it and CloseWrite + + r := io.LimitReader(wire, int64(5 * len(closeWriteTestSendData)/3)) + _, err := io.Copy(ioutil.Discard, r) + noerror(err) + + var errBuf bytes.Buffer + errBuf.Write(closeWriteErrorMsg) + _, err = io.Copy(wire, &errBuf) + noerror(err) + + err = wire.CloseWrite() + noerror(err) + + // drain wire, as documented in transport.Wire, this is the only way we know the client closed the conn + _, err = io.Copy(ioutil.Discard, wire) + if err != nil { + // io.Copy masks io.EOF to nil, and we expect io.EOF from the client's Close() call + log.Panicf("unexpected error returned from reading conn: %s", err) + } + + closeErr := wire.Close() + log.Printf("closeErr=%T %s", closeErr, closeErr) + +} diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go new file mode 100644 index 0000000..397e9d5 --- /dev/null +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_deadlines.go @@ -0,0 +1,138 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "log" + "net" + "time" + + "github.com/zrepl/zrepl/transport" +) + +type DeadlineMode uint + +const ( + DeadlineModeClientTimeout DeadlineMode = 1 + iota + DeadlineModeServerTimeout +) + +type Deadlines struct { + mode DeadlineMode +} + +func (d Deadlines) Client(wire transport.Wire) { + switch d.mode { + case 
DeadlineModeClientTimeout: + d.sleepThenSend(wire) + case DeadlineModeServerTimeout: + d.sendThenRead(wire) + default: + panic(d.mode) + } +} + +func (d Deadlines) Server(wire transport.Wire) { + switch d.mode { + case DeadlineModeClientTimeout: + d.sendThenRead(wire) + case DeadlineModeServerTimeout: + d.sleepThenSend(wire) + default: + panic(d.mode) + } +} + +var deadlinesTimeout = 1 * time.Second + +func (d Deadlines) sleepThenSend(wire transport.Wire) { + defer wire.Close() + + log.Print("sleepThenSend") + + // exceed timeout of peer (do not respond to their hi msg) + time.Sleep(3 * deadlinesTimeout) + // expect that the client has hung up on us by now + err := d.sendMsg(wire, "hi") + log.Printf("err=%s", err) + log.Printf("err=%#v", err) + if err == nil { + log.Panic("no error") + } + if _, ok := err.(net.Error); !ok { + log.Panic("not a net error") + } + +} + +func (d Deadlines) sendThenRead(wire transport.Wire) { + + log.Print("sendThenRead") + + err := d.sendMsg(wire, "hi") + noerror(err) + + err = wire.SetReadDeadline(time.Now().Add(deadlinesTimeout)) + noerror(err) + + m, err := d.recvMsg(wire) + log.Printf("m=%q", m) + log.Printf("err=%s", err) + log.Printf("err=%#v", err) + + // close asap so that the peer get's a 'connection reset by peer' error or similar + closeErr := wire.Close() + if closeErr != nil { + panic(closeErr) + } + + var neterr net.Error + var ok bool + if err == nil { + goto unexpErr // works for nil, too + } + neterr, ok = err.(net.Error) + if !ok { + log.Println("not a net error") + goto unexpErr + } + if !neterr.Timeout() { + log.Println("not a timeout") + } + + return + +unexpErr: + panic(fmt.Sprintf("sendThenRead: client should have hung up but got error %T %s", err, err)) +} + +const deadlinesMsgLen = 40 + +func (d Deadlines) sendMsg(wire transport.Wire, msg string) error { + if len(msg) > deadlinesMsgLen { + panic(len(msg)) + } + var buf [deadlinesMsgLen]byte + copy(buf[:], []byte(msg)) + n, err := wire.Write(buf[:]) + if err != 
nil { + return err + } + if n != len(buf) { + panic("short write not allowed") + } + return nil +} + +func (d Deadlines) recvMsg(wire transport.Wire) (string, error) { + + var buf bytes.Buffer + r := io.LimitReader(wire, deadlinesMsgLen) + _, err := io.Copy(&buf, r) + if err != nil { + return "", err + } + return buf.String(), nil + +} diff --git a/rpc/dataconn/timeoutconn/timeoutconn.go b/rpc/dataconn/timeoutconn/timeoutconn.go new file mode 100644 index 0000000..9e0a3bf --- /dev/null +++ b/rpc/dataconn/timeoutconn/timeoutconn.go @@ -0,0 +1,288 @@ +// package timeoutconn wraps a Wire to provide idle timeouts +// based on Set{Read,Write}Deadline. +// Additionally, it exports abstractions for vectored I/O. +package timeoutconn + +// NOTE +// Readv and Writev are not split-off into a separate package +// because we use raw syscalls, bypassing Conn's Read / Write methods. + +import ( + "errors" + "io" + "net" + "sync/atomic" + "syscall" + "time" + "unsafe" +) + +type Wire interface { + net.Conn + // A call to CloseWrite indicates that no further Write calls will be made to Wire. + // The implementation must return an error in case of Write calls after CloseWrite. + // On the peer's side, after it read all data written to Wire prior to the call to + // CloseWrite on our side, the peer's Read calls must return io.EOF. + // CloseWrite must not affect the read-direction of Wire: specifically, the + // peer must continue to be able to send, and our side must continue be + // able to receive data over Wire. + // + // Note that CloseWrite may (and most likely will) return sooner than the + // peer having received all data written to Wire prior to CloseWrite. 
+ // Note further that buffering happening in the network stacks on either side + // mandates an explicit acknowledgement from the peer that the connection may + // be fully shut down: If we call Close without such acknowledgement, any data + // from peer to us that was already in flight may cause connection resets to + // be sent from us to the peer via the specific transport protocol. Those + // resets (e.g. RST frames) may erase all connection context on the peer, + // including data in its receive buffers. Thus, those resets are in race with + // a) transmission of data written prior to CloseWrite and + // b) the peer application reading from those buffers. + // + // The WaitForPeerClose method can be used to wait for connection termination, + // iff the implementation supports it. If it does not, the only reliable way + // to wait for a peer to have read all data from Wire (until io.EOF), is to + // expect it to close the wire at that point as well, and to drain Wire until + // we also read io.EOF. + CloseWrite() error + + // Wait for the peer to close the connection. + // No data that could otherwise be Read is lost as a consequence of this call. + // The use case for this API is abortive connection shutdown. + // To provide any value over draining Wire using io.Read, an implementation + // will likely use out-of-bounds messaging mechanisms. + // TODO WaitForPeerClose() (supported bool, err error) +} + +type Conn struct { + Wire + renewDeadlinesDisabled int32 + idleTimeout time.Duration +} + +func Wrap(conn Wire, idleTimeout time.Duration) Conn { + return Conn{Wire: conn, idleTimeout: idleTimeout} +} + +// DisableTimeouts disables the idle timeout behavior provided by this package. +// Existing deadlines are cleared iff the call is the first call to this method. 
+func (c *Conn) DisableTimeouts() error { + if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) { + return c.SetDeadline(time.Time{}) + } + return nil +} + +func (c *Conn) renewReadDeadline() error { + if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { + return nil + } + return c.SetReadDeadline(time.Now().Add(c.idleTimeout)) +} + +func (c *Conn) renewWriteDeadline() error { + if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { + return nil + } + return c.SetWriteDeadline(time.Now().Add(c.idleTimeout)) +} + +func (c Conn) Read(p []byte) (n int, err error) { + n = 0 + err = nil +restart: + if err := c.renewReadDeadline(); err != nil { + return n, err + } + var nCurRead int + nCurRead, err = c.Wire.Read(p[n:len(p)]) + n += nCurRead + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurRead > 0 { + err = nil + goto restart + } + return n, err +} + +func (c Conn) Write(p []byte) (n int, err error) { + n = 0 +restart: + if err := c.renewWriteDeadline(); err != nil { + return n, err + } + var nCurWrite int + nCurWrite, err = c.Wire.Write(p[n:len(p)]) + n += nCurWrite + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 { + err = nil + goto restart + } + return n, err +} + +// Writes the given buffers to Conn, following the sematincs of io.Copy, +// but is guaranteed to use the writev system call if the wrapped Wire +// support it. +// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers). +func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) { + n = 0 +restart: + if err := c.renewWriteDeadline(); err != nil { + return n, err + } + var nCurWrite int64 + nCurWrite, err = io.Copy(c.Wire, &bufs) + n += nCurWrite + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 { + err = nil + goto restart + } + return n, err +} + +var SyscallConnNotSupported = errors.New("SyscallConn not supported") + +// The interface that must be implemented for vectored I/O support. 
+// If the wrapped Wire does not implement it, a less efficient +// fallback implementation is used. +// Rest assured that Go's *net.TCPConn implements this interface. +type SyscallConner interface { + // The sentinel error value SyscallConnNotSupported can be returned + // if the support for SyscallConn depends on runtime conditions and + // that runtime condition is not met. + SyscallConn() (syscall.RawConn, error) +} + +var _ SyscallConner = (*net.TCPConn)(nil) + +func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) { + vecs = make([]syscall.Iovec, 0, len(buffers)) + for i := range buffers { + totalLen += int64(len(buffers[i])) + if len(buffers[i]) == 0 { + continue + } + vecs = append(vecs, syscall.Iovec{ + Base: &buffers[i][0], + Len: uint64(len(buffers[i])), + }) + } + return totalLen, vecs +} + +// Reads the given buffers full: +// Think of io.ReadvFull, but for net.Buffers + using the readv syscall. +// +// If the underlying Wire is not a SyscallConner, a fallback +// ipmlementation based on repeated Conn.Read invocations is used. +// +// If the connection returned io.EOF, the number of bytes up ritten until +// then + io.EOF is returned. This behavior is different to io.ReadFull +// which returns io.ErrUnexpectedEOF. 
+func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) { + totalLen, iovecs := buildIovecs(buffers) + if debugReadvNoShortReadsAssertEnable { + defer debugReadvNoShortReadsAssert(totalLen, n, err) + } + scc, ok := c.Wire.(SyscallConner) + if !ok { + return c.readvFallback(buffers) + } + raw, err := scc.SyscallConn() + if err == SyscallConnNotSupported { + return c.readvFallback(buffers) + } + if err != nil { + return 0, err + } + n, err = c.readv(raw, iovecs) + return +} + +func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) { + buffers := [][]byte(nbuffers) + for i := range buffers { + curBuf := buffers[i] + inner: + for len(curBuf) > 0 { + if err := c.renewReadDeadline(); err != nil { + return n, err + } + var oneN int + oneN, err = c.Read(curBuf[:]) // WE WANT NO SHADOWING + curBuf = curBuf[oneN:] + n += int64(oneN) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() && oneN > 0 { + continue inner + } + return n, err + } + } + } + return n, nil +} + +func (c Conn) readv(rawConn syscall.RawConn, iovecs []syscall.Iovec) (n int64, err error) { + for len(iovecs) > 0 { + if err := c.renewReadDeadline(); err != nil { + return n, err + } + oneN, oneErr := c.doOneReadv(rawConn, &iovecs) + n += oneN + if netErr, ok := oneErr.(net.Error); ok && netErr.Timeout() && oneN > 0 { // TODO likely not working + continue + } else if oneErr == nil && oneN > 0 { + continue + } else { + return n, oneErr + } + } + return n, nil +} + +func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) { + rawReadErr := rawConn.Read(func(fd uintptr) (done bool) { + // iovecs, n and err must not be shadowed! 
+ thisReadN, _, errno := syscall.Syscall( + syscall.SYS_READV, + fd, + uintptr(unsafe.Pointer(&(*iovecs)[0])), + uintptr(len(*iovecs)), + ) + if thisReadN == ^uintptr(0) { + if errno == syscall.EAGAIN { + return false + } + err = syscall.Errno(errno) + return true + } + if int(thisReadN) < 0 { + panic("unexpected return value") + } + n += int64(thisReadN) + // shift iovecs forward + for left := int64(thisReadN); left > 0; { + curVecNewLength := int64((*iovecs)[0].Len) - left // TODO assert conversion + if curVecNewLength <= 0 { + left -= int64((*iovecs)[0].Len) + *iovecs = (*iovecs)[1:] + } else { + (*iovecs)[0].Base = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer((*iovecs)[0].Base)) + uintptr(left))) + (*iovecs)[0].Len = uint64(curVecNewLength) + break // inner + } + } + if thisReadN == 0 { + err = io.EOF + return true + } + return true + }) + + if rawReadErr != nil { + err = rawReadErr + } + + return n, err +} diff --git a/rpc/dataconn/timeoutconn/timeoutconn_debug.go b/rpc/dataconn/timeoutconn/timeoutconn_debug.go new file mode 100644 index 0000000..d5a7cb3 --- /dev/null +++ b/rpc/dataconn/timeoutconn/timeoutconn_debug.go @@ -0,0 +1,19 @@ +package timeoutconn + +import ( + "fmt" + "io" +) + +const debugReadvNoShortReadsAssertEnable = false + +func debugReadvNoShortReadsAssert(expectedLen, returnedLen int64, returnedErr error) { + readShort := expectedLen != returnedLen + if !readShort { + return + } + if returnedErr != io.EOF { + return + } + panic(fmt.Sprintf("ReadvFull short and error is not EOF%v\n", returnedErr)) +} diff --git a/rpc/dataconn/timeoutconn/timeoutconn_test.go b/rpc/dataconn/timeoutconn/timeoutconn_test.go new file mode 100644 index 0000000..708bba2 --- /dev/null +++ b/rpc/dataconn/timeoutconn/timeoutconn_test.go @@ -0,0 +1,177 @@ +package timeoutconn + +import ( + "bytes" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zrepl/zrepl/util/socketpair" +) + 
+func TestReadTimeout(t *testing.T) { + + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + var buf bytes.Buffer + buf.WriteString("tooktoolong") + time.Sleep(500 * time.Millisecond) + _, err := io.Copy(a, &buf) + require.NoError(t, err) + }() + + go func() { + defer wg.Done() + conn := Wrap(b, 100*time.Millisecond) + buf := [4]byte{} // shorter than message put on wire + n, err := conn.Read(buf[:]) + assert.Equal(t, 0, n) + assert.Error(t, err) + netErr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, netErr.Timeout()) + }() + + wg.Wait() +} + +type writeBlockConn struct { + net.Conn + blockTime time.Duration +} + +func (c writeBlockConn) Write(p []byte) (int, error) { + time.Sleep(c.blockTime) + return c.Conn.Write(p) +} + +func (c writeBlockConn) CloseWrite() error { + return c.Conn.Close() +} + +func TestWriteTimeout(t *testing.T) { + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + var buf bytes.Buffer + buf.WriteString("message") + blockConn := writeBlockConn{a, 500 * time.Millisecond} + conn := Wrap(blockConn, 100*time.Millisecond) + n, err := conn.Write(buf.Bytes()) + assert.Equal(t, 0, n) + assert.Error(t, err) + netErr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, netErr.Timeout()) +} + +func TestNoPartialReadsDueToDeadline(t *testing.T) { + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer a.Close() + defer b.Close() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + a.Write([]byte{1, 2, 3, 4, 5}) + // sleep to provoke a partial read in the consumer goroutine + time.Sleep(50 * time.Millisecond) + a.Write([]byte{6, 7, 8, 9, 10}) + }() + + go func() { + defer wg.Done() + bc := Wrap(b, 100*time.Millisecond) + var buf bytes.Buffer + beginRead := time.Now() + // io.Copy will encounter a partial read, then wait ~50ms 
until the other 5 bytes are written + // It is still going to fail with deadline err because it expects EOF + n, err := io.Copy(&buf, bc) + readDuration := time.Now().Sub(beginRead) + t.Logf("read duration=%s", readDuration) + t.Logf("recv done n=%v err=%v", n, err) + t.Logf("buf=%v", buf.Bytes()) + neterr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, neterr.Timeout()) + + assert.Equal(t, int64(10), n) + assert.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, buf.Bytes()) + // 50ms for the second read, 100ms after that one for the deadline + // allow for some jitter + assert.True(t, readDuration > 140*time.Millisecond) + assert.True(t, readDuration < 160*time.Millisecond) + }() + + wg.Wait() +} + +type partialWriteMockConn struct { + net.Conn // to satisfy interface + buf bytes.Buffer + writeDuration time.Duration + returnAfterBytesWritten int +} + +func newPartialWriteMockConn(writeDuration time.Duration, returnAfterBytesWritten int) *partialWriteMockConn { + return &partialWriteMockConn{ + writeDuration: writeDuration, + returnAfterBytesWritten: returnAfterBytesWritten, + } +} + +func (c *partialWriteMockConn) Write(p []byte) (int, error) { + time.Sleep(c.writeDuration) + consumeBytes := len(p) + if consumeBytes > c.returnAfterBytesWritten { + consumeBytes = c.returnAfterBytesWritten + } + n, err := c.buf.Write(p[0:consumeBytes]) + if err != nil || n != consumeBytes { + panic("bytes.Buffer behaves unexpectedly") + } + return n, nil +} + +func TestPartialWriteMockConn(t *testing.T) { + mc := newPartialWriteMockConn(100*time.Millisecond, 5) + buf := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + begin := time.Now() + n, err := mc.Write(buf[:]) + duration := time.Now().Sub(begin) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.True(t, duration > 100*time.Millisecond) + assert.True(t, duration < 150*time.Millisecond) +} + +func TestNoPartialWritesDueToDeadline(t *testing.T) { + a, b, err := socketpair.SocketPair() + require.NoError(t, err) + defer 
a.Close() + defer b.Close() + var buf bytes.Buffer + buf.WriteString("message") + blockConn := writeBlockConn{a, 150 * time.Millisecond} + conn := Wrap(blockConn, 100*time.Millisecond) + n, err := conn.Write(buf.Bytes()) + assert.Equal(t, 0, n) + assert.Error(t, err) + netErr, ok := err.(net.Error) + require.True(t, ok) + assert.True(t, netErr.Timeout()) +} diff --git a/rpc/grpcclientidentity/authlistener_grpc_adaptor.go b/rpc/grpcclientidentity/authlistener_grpc_adaptor.go new file mode 100644 index 0000000..accd124 --- /dev/null +++ b/rpc/grpcclientidentity/authlistener_grpc_adaptor.go @@ -0,0 +1,118 @@ +// Package grpcclientidentity makes the client identity +// provided by github.com/zrepl/zrepl/daemon/transport/serve.{AuthenticatedListener,AuthConn} +// available to gRPC service handlers. +// +// This goal is achieved through the combination of custom gRPC transport credentials and two interceptors +// (i.e. middleware). +// +// For gRPC clients, the TransportCredentials + Dialer can be used to construct a gRPC client (grpc.ClientConn) +// that uses a github.com/zrepl/zrepl/daemon/transport/connect.Connecter to connect to a server. +// +// The adaptors exposed by this package must be used together, and panic if they are not. +// See package grpchelper for a more restrictive but safe example on how the adaptors should be composed. 
+package grpcclientidentity + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/transport" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +type Logger = logger.Logger + +type GRPCDialFunction = func(string, time.Duration) (net.Conn, error) + +func NewDialer(logger Logger, connecter transport.Connecter) GRPCDialFunction { + return func(s string, duration time.Duration) (conn net.Conn, e error) { + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + nc, err := connecter.Connect(ctx) + // TODO find better place (callback from gRPC?) where to log errors + // we want the users to know, though + if err != nil { + logger.WithError(err).Error("cannot connect") + } + return nc, err + } +} + +type authConnAuthType struct { + clientIdentity string +} + +func (authConnAuthType) AuthType() string { + return "AuthConn" +} + +type connecterAuthType struct{} + +func (connecterAuthType) AuthType() string { + return "connecter" +} + +type transportCredentials struct { + logger Logger +} + +// Use on both sides as ServerOption or ClientOption. 
+func NewTransportCredentials(log Logger) credentials.TransportCredentials { + if log == nil { + log = logger.NewNullLogger() + } + return &transportCredentials{log} +} + +func (c *transportCredentials) ClientHandshake(ctx context.Context, s string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.logger.WithField("url", s).WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ClientHandshake") + // do nothing, client credential is only for WithInsecure warning to go away + // the authentication is done by the connecter + return rawConn, &connecterAuthType{}, nil +} + +func (c *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.logger.WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ServerHandshake") + authConn, ok := rawConn.(*transport.AuthConn) + if !ok { + panic(fmt.Sprintf("NewTransportCredentials must be used with a listener that returns *transport.AuthConn, got %T", rawConn)) + } + return rawConn, &authConnAuthType{authConn.ClientIdentity()}, nil +} + +func (*transportCredentials) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} // TODO +} + +func (t *transportCredentials) Clone() credentials.TransportCredentials { + var x = *t + return &x +} + +func (*transportCredentials) OverrideServerName(string) error { + panic("not implemented") +} + +func NewInterceptors(logger Logger, clientIdentityKey interface{}) (unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) { + unary = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + logger.WithField("fullMethod", info.FullMethod).Debug("request") + p, ok := peer.FromContext(ctx) + if !ok { + panic("peer.FromContext expected to return a peer in grpc.UnaryServerInterceptor") + } + logger.WithField("peer", fmt.Sprintf("%v", p)).Debug("peer") + a, ok := p.AuthInfo.(*authConnAuthType) + if !ok { + 
panic(fmt.Sprintf("NewInterceptors must be used in combination with grpc.NewTransportCredentials, but got auth type %T", p.AuthInfo)) + } + ctx = context.WithValue(ctx, clientIdentityKey, a.clientIdentity) + return handler(ctx, req) + } + stream = nil + return +} diff --git a/rpc/grpcclientidentity/example/grpcauth.proto b/rpc/grpcclientidentity/example/grpcauth.proto new file mode 100644 index 0000000..5adf606 --- /dev/null +++ b/rpc/grpcclientidentity/example/grpcauth.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package pdu; + + +service Greeter { + rpc Greet(GreetRequest) returns (GreetResponse) {} +} + +message GreetRequest { + string name = 1; +} + +message GreetResponse { + string msg = 1; +} \ No newline at end of file diff --git a/rpc/grpcclientidentity/example/main.go b/rpc/grpcclientidentity/example/main.go new file mode 100644 index 0000000..3813683 --- /dev/null +++ b/rpc/grpcclientidentity/example/main.go @@ -0,0 +1,107 @@ +// This package demonstrates how the grpcclientidentity package can be used +// to set up a gRPC greeter service. 
+package main + +import ( + "context" + "flag" + "fmt" + "os" + "time" + + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/grpcclientidentity/example/pdu" + "github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper" + "github.com/zrepl/zrepl/transport/tcp" +) + +var args struct { + mode string +} + +var log = logger.NewStderrDebugLogger() + +func main() { + flag.StringVar(&args.mode, "mode", "", "client|server") + flag.Parse() + + switch args.mode { + case "client": + client() + case "server": + server() + default: + log.Printf("unknown mode %q") + os.Exit(1) + } +} + +func onErr(err error, format string, args ...interface{}) { + log.WithError(err).Error(fmt.Sprintf("%s: %s", fmt.Sprintf(format, args...), err)) + os.Exit(1) +} + +func client() { + cn, err := tcp.TCPConnecterFromConfig(&config.TCPConnect{ + ConnectCommon: config.ConnectCommon{ + Type: "tcp", + }, + Address: "127.0.0.1:8080", + DialTimeout: 10 * time.Second, + }) + if err != nil { + onErr(err, "build connecter error") + } + + clientConn := grpchelper.ClientConn(cn, log) + defer clientConn.Close() + + // normal usage from here on + + client := pdu.NewGreeterClient(clientConn) + resp, err := client.Greet(context.Background(), &pdu.GreetRequest{Name: "somethingimadeup"}) + if err != nil { + onErr(err, "RPC error") + } + + fmt.Printf("got response:\n\t%s\n", resp.GetMsg()) +} + +const clientIdentityKey = "clientIdentity" + +func server() { + authListenerFactory, err := tcp.TCPListenerFactoryFromConfig(nil, &config.TCPServe{ + ServeCommon: config.ServeCommon{ + Type: "tcp", + }, + Listen: "127.0.0.1:8080", + Clients: map[string]string{ + "127.0.0.1": "localclient", + "::1": "localclient", + }, + }) + if err != nil { + onErr(err, "cannot build listener factory") + } + + log := logger.NewStderrDebugLogger() + + srv, serve, err := grpchelper.NewServer(authListenerFactory, clientIdentityKey, log) + svc := &greeter{"hello "} + pdu.RegisterGreeterServer(srv, svc) 
+ + if err := serve(); err != nil { + onErr(err, "error serving") + } +} + +type greeter struct { + prepend string +} + +func (g *greeter) Greet(ctx context.Context, r *pdu.GreetRequest) (*pdu.GreetResponse, error) { + ci, _ := ctx.Value(clientIdentityKey).(string) + log.WithField("clientIdentity", ci).Info("Greet() request") // show that we got the client identity + return &pdu.GreetResponse{Msg: fmt.Sprintf("%s%s (clientIdentity=%q)", g.prepend, r.GetName(), ci)}, nil +} diff --git a/rpc/grpcclientidentity/example/pdu/grpcauth.pb.go b/rpc/grpcclientidentity/example/pdu/grpcauth.pb.go new file mode 100644 index 0000000..0fbf2cb --- /dev/null +++ b/rpc/grpcclientidentity/example/pdu/grpcauth.pb.go @@ -0,0 +1,193 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: grpcauth.proto + +package pdu + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type GreetRequest struct { + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GreetRequest) Reset() { *m = GreetRequest{} } +func (m *GreetRequest) String() string { return proto.CompactTextString(m) } +func (*GreetRequest) ProtoMessage() {} +func (*GreetRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_1dfba7be0cf69353, []int{0} +} + +func (m *GreetRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GreetRequest.Unmarshal(m, b) +} +func (m *GreetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GreetRequest.Marshal(b, m, deterministic) +} +func (m *GreetRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GreetRequest.Merge(m, src) +} +func (m *GreetRequest) XXX_Size() int { + return xxx_messageInfo_GreetRequest.Size(m) +} +func (m *GreetRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GreetRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GreetRequest proto.InternalMessageInfo + +func (m *GreetRequest) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +type GreetResponse struct { + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GreetResponse) Reset() { *m = GreetResponse{} } +func (m *GreetResponse) String() string { return proto.CompactTextString(m) } +func (*GreetResponse) ProtoMessage() {} +func (*GreetResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_1dfba7be0cf69353, []int{1} +} + +func (m *GreetResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GreetResponse.Unmarshal(m, b) +} +func (m *GreetResponse) XXX_Marshal(b []byte, deterministic bool) 
([]byte, error) { + return xxx_messageInfo_GreetResponse.Marshal(b, m, deterministic) +} +func (m *GreetResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_GreetResponse.Merge(m, src) +} +func (m *GreetResponse) XXX_Size() int { + return xxx_messageInfo_GreetResponse.Size(m) +} +func (m *GreetResponse) XXX_DiscardUnknown() { + xxx_messageInfo_GreetResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_GreetResponse proto.InternalMessageInfo + +func (m *GreetResponse) GetMsg() string { + if m != nil { + return m.Msg + } + return "" +} + +func init() { + proto.RegisterType((*GreetRequest)(nil), "pdu.GreetRequest") + proto.RegisterType((*GreetResponse)(nil), "pdu.GreetResponse") +} + +func init() { proto.RegisterFile("grpcauth.proto", fileDescriptor_1dfba7be0cf69353) } + +var fileDescriptor_1dfba7be0cf69353 = []byte{ + // 137 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x2f, 0x2a, 0x48, + 0x4e, 0x2c, 0x2d, 0xc9, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2e, 0x48, 0x29, 0x55, + 0x52, 0xe2, 0xe2, 0x71, 0x2f, 0x4a, 0x4d, 0x2d, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, + 0x12, 0xe2, 0x62, 0xc9, 0x4b, 0xcc, 0x4d, 0x95, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x02, 0xb3, + 0x95, 0x14, 0xb9, 0x78, 0xa1, 0x6a, 0x8a, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0x04, 0xb8, 0x98, + 0x73, 0x8b, 0xd3, 0xa1, 0x6a, 0x40, 0x4c, 0x23, 0x6b, 0x2e, 0x76, 0xb0, 0x92, 0xd4, 0x22, 0x21, + 0x03, 0x2e, 0x56, 0x30, 0x53, 0x48, 0x50, 0xaf, 0x20, 0xa5, 0x54, 0x0f, 0xd9, 0x74, 0x29, 0x21, + 0x64, 0x21, 0x88, 0x61, 0x4a, 0x0c, 0x49, 0x6c, 0x60, 0xf7, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff, + 0xff, 0xa8, 0x53, 0x2f, 0x4c, 0xa1, 0x00, 0x00, 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. 
+const _ = grpc.SupportPackageIsVersion4 + +// GreeterClient is the client API for Greeter service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type GreeterClient interface { + Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error) +} + +type greeterClient struct { + cc *grpc.ClientConn +} + +func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { + return &greeterClient{cc} +} + +func (c *greeterClient) Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error) { + out := new(GreetResponse) + err := c.cc.Invoke(ctx, "/pdu.Greeter/Greet", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// GreeterServer is the server API for Greeter service. +type GreeterServer interface { + Greet(context.Context, *GreetRequest) (*GreetResponse, error) +} + +func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { + s.RegisterService(&_Greeter_serviceDesc, srv) +} + +func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GreetRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GreeterServer).Greet(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/pdu.Greeter/Greet", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GreeterServer).Greet(ctx, req.(*GreetRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _Greeter_serviceDesc = grpc.ServiceDesc{ + ServiceName: "pdu.Greeter", + HandlerType: (*GreeterServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Greet", + Handler: _Greeter_Greet_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "grpcauth.proto", +} diff --git 
a/rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go b/rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go new file mode 100644 index 0000000..7e46681 --- /dev/null +++ b/rpc/grpcclientidentity/grpchelper/authlistener_grpc_adaptor_wrapper.go @@ -0,0 +1,76 @@ +// Package grpchelper wraps the adaptors implemented by package grpcclientidentity into a less flexible API +// which, however, ensures that the individual adaptor primitive's expectations are met and hence do not panic. +package grpchelper + +import ( + "context" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/grpcclientidentity" + "github.com/zrepl/zrepl/rpc/netadaptor" + "github.com/zrepl/zrepl/transport" +) + +// The following constants are relevant for interoperability. +// We use the same values for client & server, because zrepl is more +// symmetrical ("one source, one sink") instead of the typical +// gRPC scenario ("many clients, single server") +const ( + StartKeepalivesAfterInactivityDuration = 5 * time.Second + KeepalivePeerTimeout = 10 * time.Second +) + +type Logger = logger.Logger + +// ClientConn is an easy-to-use wrapper around the Dialer and TransportCredentials interface +// to produce a grpc.ClientConn +func ClientConn(cn transport.Connecter, log Logger) *grpc.ClientConn { + ka := grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: StartKeepalivesAfterInactivityDuration, + Timeout: KeepalivePeerTimeout, + PermitWithoutStream: true, + }) + dialerOption := grpc.WithDialer(grpcclientidentity.NewDialer(log, cn)) + cred := grpc.WithTransportCredentials(grpcclientidentity.NewTransportCredentials(log)) + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, "doesn't matter done by dialer", dialerOption, cred, ka) + if err != nil { + log.WithError(err).Error("cannot create gRPC client conn 
(non-blocking)") + // It's ok to panic here: the we call grpc.DialContext without the + // (grpc.WithBlock) dial option, and at the time of writing, the grpc + // docs state that no connection attempt is made in that case. + // Hence, any error that occurs is due to DialOptions or similar, + // and thus indicative of an implementation error. + panic(err) + } + return cc +} + +// NewServer is a convenience interface around the TransportCredentials and Interceptors interface. +func NewServer(authListenerFactory transport.AuthenticatedListenerFactory, clientIdentityKey interface{}, logger grpcclientidentity.Logger) (srv *grpc.Server, serve func() error, err error) { + ka := grpc.KeepaliveParams(keepalive.ServerParameters{ + Time: StartKeepalivesAfterInactivityDuration, + Timeout: KeepalivePeerTimeout, + }) + tcs := grpcclientidentity.NewTransportCredentials(logger) + unary, stream := grpcclientidentity.NewInterceptors(logger, clientIdentityKey) + srv = grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream), ka) + + serve = func() error { + l, err := authListenerFactory() + if err != nil { + return err + } + if err := srv.Serve(netadaptor.New(l, logger)); err != nil { + return err + } + return nil + } + + return srv, serve, nil +} diff --git a/rpc/netadaptor/authlistener_netlistener_adaptor.go b/rpc/netadaptor/authlistener_netlistener_adaptor.go new file mode 100644 index 0000000..cc9ae70 --- /dev/null +++ b/rpc/netadaptor/authlistener_netlistener_adaptor.go @@ -0,0 +1,102 @@ +// Package netadaptor implements an adaptor from +// transport.AuthenticatedListener to net.Listener. +// +// In contrast to transport.AuthenticatedListener, +// net.Listener is commonly expected (e.g. 
by net/http.Server.Serve), +// to return errors that fulfill the Temporary interface: +// interface Temporary { Temporary() bool } +// Common behavior of packages consuming net.Listener is to return +// from the serve-loop if an error returned by Accept is not Temporary, +// i.e., does not implement the interface or is !net.Error.Temporary(). +// +// The zrepl transport infrastructure was written with the +// idea that Accept() may return any kind of error, and the consumer +// would just log the error and continue calling Accept(). +// We have to adapt these listeners' behavior to the expectations +// of net/http.Server. +// +// Hence, Listener does not return an error at all but blocks the +// caller of Accept() until we get a (successfully authenticated) +// connection without errors from the transport. +// Accept errors returned from the transport are logged as errors +// to the logger passed on initialization. +package netadaptor + +import ( + "context" + "fmt" + "github.com/zrepl/zrepl/logger" + "net" + "github.com/zrepl/zrepl/transport" +) + +type Logger = logger.Logger + +type acceptReq struct { + callback chan net.Conn +} + +type Listener struct { + al transport.AuthenticatedListener + logger Logger + accepts chan acceptReq + stop chan struct{} +} + +// Consume the given authListener and wrap it as a *Listener, which implements net.Listener. +// The returned net.Listener must be closed. +// The wrapped authListener is closed when the returned net.Listener is closed. +func New(authListener transport.AuthenticatedListener, l Logger) *Listener { + if l == nil { + l = logger.NewNullLogger() + } + a := &Listener{ + al: authListener, + logger: l, + accepts: make(chan acceptReq), + stop: make(chan struct{}), + } + go a.handleAccept() + return a +} + +// The returned net.Conn is guaranteed to be *transport.AuthConn, i.e., the type of connection +// returned by the wrapped transport.AuthenticatedListener. 
+func (a Listener) Accept() (net.Conn, error) { + req := acceptReq{make(chan net.Conn, 1)} + a.accepts <- req + conn := <-req.callback + return conn, nil +} + +func (a Listener) handleAccept() { + for { + select { + case <-a.stop: + a.logger.Debug("handleAccept stop accepting") + return + case req := <-a.accepts: + for { + a.logger.Debug("accept authListener") + authConn, err := a.al.Accept(context.Background()) + if err != nil { + a.logger.WithError(err).Error("accept error") + continue + } + a.logger.WithField("type", fmt.Sprintf("%T", authConn)). + Debug("accept complete") + req.callback <- authConn + break + } + } + } +} + +func (a Listener) Addr() net.Addr { + return a.al.Addr() +} + +func (a Listener) Close() error { + close(a.stop) + return a.al.Close() +} diff --git a/rpc/rpc_client.go b/rpc/rpc_client.go new file mode 100644 index 0000000..efcaf98 --- /dev/null +++ b/rpc/rpc_client.go @@ -0,0 +1,106 @@ +package rpc + +import ( + "context" + "net" + "time" + + "google.golang.org/grpc" + + "github.com/zrepl/zrepl/replication" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper" + "github.com/zrepl/zrepl/rpc/versionhandshake" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/envconst" + "github.com/zrepl/zrepl/zfs" +) + +// Client implements the active side of a replication setup. +// It satisfies the Endpoint, Sender and Receiver interface defined by package replication. 
+type Client struct { + dataClient *dataconn.Client + controlClient pdu.ReplicationClient // this the grpc client instance, see constructor + controlConn *grpc.ClientConn + loggers Loggers +} + +var _ replication.Endpoint = &Client{} +var _ replication.Sender = &Client{} +var _ replication.Receiver = &Client{} + +type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error) + +// config must be validated, NewClient will panic if it is not valid +func NewClient(cn transport.Connecter, loggers Loggers) *Client { + + cn = versionhandshake.Connecter(cn, envconst.Duration("ZREPL_RPC_CLIENT_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)) + + muxedConnecter := mux(cn) + + c := &Client{ + loggers: loggers, + } + grpcConn := grpchelper.ClientConn(muxedConnecter.control, loggers.Control) + + go func() { + for { + state := grpcConn.GetState() + loggers.General.WithField("grpc_state", state.String()).Debug("grpc state change") + grpcConn.WaitForStateChange(context.TODO(), state) + } + }() + c.controlClient = pdu.NewReplicationClient(grpcConn) + c.controlConn = grpcConn + + c.dataClient = dataconn.NewClient(muxedConnecter.data, loggers.Data) + return c +} + +func (c *Client) Close() { + if err := c.controlConn.Close(); err != nil { + c.loggers.General.WithError(err).Error("cannot cloe control connection") + } + // TODO c.dataClient should have Close() +} + +// callers must ensure that the returned io.ReadCloser is closed +// TODO expose dataClient interface to the outside world +func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + // TODO the returned sendStream may return a read error created by the remote side + res, streamCopier, err := c.dataClient.ReqSend(ctx, r) + if err != nil { + return nil, nil, nil + } + if streamCopier == nil { + return res, nil, nil + } + + return res, streamCopier, nil + +} + +func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier 
zfs.StreamCopier) (*pdu.ReceiveRes, error) { + return c.dataClient.ReqRecv(ctx, req, streamCopier) +} + +func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { + return c.controlClient.ListFilesystems(ctx, in) +} + +func (c *Client) ListFilesystemVersions(ctx context.Context, in *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) { + return c.controlClient.ListFilesystemVersions(ctx, in) +} + +func (c *Client) DestroySnapshots(ctx context.Context, in *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) { + return c.controlClient.DestroySnapshots(ctx, in) +} + +func (c *Client) ReplicationCursor(ctx context.Context, in *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) { + return c.controlClient.ReplicationCursor(ctx, in) +} + +func (c *Client) ResetConnectBackoff() { + c.controlConn.ResetConnectBackoff() +} diff --git a/rpc/rpc_debug.go b/rpc/rpc_debug.go new file mode 100644 index 0000000..a31e1f2 --- /dev/null +++ b/rpc/rpc_debug.go @@ -0,0 +1,20 @@ +package rpc + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_RPC_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "rpc: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/rpc/rpc_doc.go b/rpc/rpc_doc.go new file mode 100644 index 0000000..dbe7dda --- /dev/null +++ b/rpc/rpc_doc.go @@ -0,0 +1,118 @@ +// Package rpc implements zrepl daemon-to-daemon RPC protocol +// on top of a transport provided by package transport. +// The zrepl documentation refers to the client as the +// `active side` and to the server as the `passive side`. 
+// +// Design Considerations +// +// zrepl has non-standard requirements to remote procedure calls (RPC): +// whereas the coordination of replication (the planning phase) mostly +// consists of regular unary calls, the actual data transfer requires +// a high-throughput, low-overhead solution. +// +// Specifically, the requirements for data transfer is to perform +// a copy of an io.Reader over the wire, such that an io.EOF of the original +// reader corresponds to an io.EOF on the receiving side. +// If any other error occurs while reading from the original io.Reader +// on the sender, the receiver should receive the contents of that error +// in some form (usually as a trailer message) +// A common implementation technique for above data transfer is chunking, +// for example in HTTP: +// https://tools.ietf.org/html/rfc2616#section-3.6.1 +// +// With regards to regular RPCs, gRPC is a popular implementation based +// on protocol buffers and thus code generation. +// gRPC also supports bi-directional streaming RPC, and it is possible +// to implement chunked transfer through the use of streams. +// +// For zrepl however, both HTTP and manually implemented chunked transfer +// using gRPC were found to have significant CPU overhead at transfer +// speeds to be expected even with hobbyist users. +// +// However, it is nice to piggyback on the code generation provided +// by protobuf / gRPC, in particular since the majority of call types +// are regular unary RPCs for which the higher overhead of gRPC is acceptable. +// +// Hence, this package attempts to combine the best of both worlds: +// +// GRPC for Coordination and Dataconn for Bulk Data Transfer +// +// This package's Client uses its transport.Connecter to maintain +// separate control and data connections to the Server. +// The control connection is used by an instance of pdu.ReplicationClient +// whose 'regular' unary RPC calls are re-exported. 
+// The data connection is used by an instance of dataconn.Client and +// is used for bulk data transfer, namely `Send` and `Receive`. +// +// The following ASCII diagram gives an overview of how the individual +// building blocks are glued together: +// +// +------------+ +// | rpc.Client | +// +------------+ +// | | +// +--------+ +------------+ +// | | +// +---------v-----------+ +--------v------+ +// |pdu.ReplicationClient| |dataconn.Client| +// +---------------------+ +--------v------+ +// | label: label: | +// | zrepl_control zrepl_data | +// +--------+ +------------+ +// | | +// +--v---------v---+ +// | transportmux | +// +-------+--------+ +// | uses +// +-------v--------+ +// |versionhandshake| +// +-------+--------+ +// | uses +// +------v------+ +// | transport | +// +------+------+ +// | +// NETWORK +// | +// +------+------+ +// | transport | +// +------^------+ +// | uses +// +-------+--------+ +// |versionhandshake| +// +-------^--------+ +// | uses +// +-------+--------+ +// | transportmux | +// +--^--------^----+ +// | | +// +--------+ --------------+ --- +// | | | +// | label: label: | | +// | zrepl_control zrepl_data | | +// +-----+----+ +-----------+---+ | +// |netadaptor| |dataconn.Server| | rpc.Server +// | + | +------+--------+ | +// |grpcclient| | | +// |identity | | | +// +-----+----+ | | +// | | | +// +---------v-----------+ | | +// |pdu.ReplicationServer| | | +// +---------+-----------+ | | +// | | --- +// +----------+ +------------+ +// | | +// +---v--v-----+ +// | Handler | +// +------------+ +// (usually endpoint.{Sender,Receiver}) +// +// +package rpc + +// edit trick for the ASCII art above: +// - remove the comments // +// - vim: set virtualedit+=all +// - vim: set ft=text + diff --git a/rpc/rpc_logging.go b/rpc/rpc_logging.go new file mode 100644 index 0000000..e842f6e --- /dev/null +++ b/rpc/rpc_logging.go @@ -0,0 +1,34 @@ +package rpc + +import ( + "context" + + "github.com/zrepl/zrepl/logger" +) + +type Logger = logger.Logger + 
+type contextKey int + +const ( + contextKeyLoggers contextKey = iota + contextKeyGeneralLogger + contextKeyControlLogger + contextKeyDataLogger +) + +/// All fields must be non-nil +type Loggers struct { + General Logger + Control Logger + Data Logger +} + +func WithLoggers(ctx context.Context, loggers Loggers) context.Context { + ctx = context.WithValue(ctx, contextKeyLoggers, loggers) + return ctx +} + +func GetLoggersOrPanic(ctx context.Context) Loggers { + return ctx.Value(contextKeyLoggers).(Loggers) +} diff --git a/rpc/rpc_mux.go b/rpc/rpc_mux.go new file mode 100644 index 0000000..f70dde7 --- /dev/null +++ b/rpc/rpc_mux.go @@ -0,0 +1,57 @@ +package rpc + +import ( + "context" + "time" + + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/rpc/transportmux" + "github.com/zrepl/zrepl/util/envconst" +) + +type demuxedListener struct { + control, data transport.AuthenticatedListener +} + +const ( + transportmuxLabelControl string = "zrepl_control" + transportmuxLabelData string = "zrepl_data" +) + +func demux(serveCtx context.Context, listener transport.AuthenticatedListener) demuxedListener { + listeners, err := transportmux.Demux( + serveCtx, listener, + []string{transportmuxLabelControl, transportmuxLabelData}, + envconst.Duration("ZREPL_TRANSPORT_DEMUX_TIMEOUT", 10*time.Second), + ) + if err != nil { + // transportmux API guarantees that the returned error can only be due + // to invalid API usage (i.e. 
labels too long) + panic(err) + } + return demuxedListener{ + control: listeners[transportmuxLabelControl], + data: listeners[transportmuxLabelData], + } +} + +type muxedConnecter struct { + control, data transport.Connecter +} + +func mux(rawConnecter transport.Connecter) muxedConnecter { + muxedConnecters, err := transportmux.MuxConnecter( + rawConnecter, + []string{transportmuxLabelControl, transportmuxLabelData}, + envconst.Duration("ZREPL_TRANSPORT_MUX_TIMEOUT", 10*time.Second), + ) + if err != nil { + // transportmux API guarantees that the returned error can only be due + // to invalid API usage (i.e. labels too long) + panic(err) + } + return muxedConnecter{ + control: muxedConnecters[transportmuxLabelControl], + data: muxedConnecters[transportmuxLabelData], + } +} diff --git a/rpc/rpc_server.go b/rpc/rpc_server.go new file mode 100644 index 0000000..3abbc18 --- /dev/null +++ b/rpc/rpc_server.go @@ -0,0 +1,119 @@ +package rpc + +import ( + "context" + "time" + + "google.golang.org/grpc" + + "github.com/zrepl/zrepl/endpoint" + "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/rpc/grpcclientidentity" + "github.com/zrepl/zrepl/rpc/netadaptor" + "github.com/zrepl/zrepl/rpc/versionhandshake" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/envconst" +) + +type Handler interface { + pdu.ReplicationServer + dataconn.Handler +} + +type serveFunc func(ctx context.Context, demuxedListener transport.AuthenticatedListener, errOut chan<- error) + +// Server abstracts the accept and request routing infrastructure for the +// passive side of a replication setup. +type Server struct { + logger Logger + handler Handler + controlServer *grpc.Server + controlServerServe serveFunc + dataServer *dataconn.Server + dataServerServe serveFunc +} + +type serverContextKey int + +type HandlerContextInterceptor func(ctx context.Context) context.Context + +// config must be valid (use its Validate function). 
+func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextInterceptor) *Server {
+
+	// setup control server
+	tcs := grpcclientidentity.NewTransportCredentials(loggers.Control) // TODO different subsystem for log
+	unary, stream := grpcclientidentity.NewInterceptors(loggers.Control, endpoint.ClientIdentityKey)
+	controlServer := grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream))
+	pdu.RegisterReplicationServer(controlServer, handler)
+	controlServerServe := func(ctx context.Context, controlListener transport.AuthenticatedListener, errOut chan<- error) {
+		// give time for graceful stop until deadline expires, then hard stop
+		go func() {
+			<-ctx.Done()
+			if dl, ok := ctx.Deadline(); ok {
+				go time.AfterFunc(time.Until(dl), controlServer.Stop)
+			}
+			loggers.Control.Debug("shutting down control server")
+			controlServer.GracefulStop()
+		}()
+
+		errOut <- controlServer.Serve(netadaptor.New(controlListener, loggers.Control))
+	}
+
+	// setup data server
+	dataServerClientIdentitySetter := func(ctx context.Context, wire *transport.AuthConn) (context.Context, *transport.AuthConn) {
+		ci := wire.ClientIdentity()
+		ctx = context.WithValue(ctx, endpoint.ClientIdentityKey, ci)
+		if ctxInterceptor != nil {
+			ctx = ctxInterceptor(ctx) // SHADOWING
+		}
+		return ctx, wire
+	}
+	dataServer := dataconn.NewServer(dataServerClientIdentitySetter, loggers.Data, handler)
+	dataServerServe := func(ctx context.Context, dataListener transport.AuthenticatedListener, errOut chan<- error) {
+		dataServer.Serve(ctx, dataListener)
+		errOut <- nil // TODO bad design of dataServer?
+	}
+
+	server := &Server{
+		logger:             loggers.General,
+		handler:            handler,
+		controlServer:      controlServer,
+		controlServerServe: controlServerServe,
+		dataServer:         dataServer,
+		dataServerServe:    dataServerServe,
+	}
+
+	return server
+}
+
+// The context is used for cancellation only.
+// Serve never returns an error, it logs them to the Server's logger. 
+func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) { + ctx, cancel := context.WithCancel(ctx) + + l = versionhandshake.Listener(l, envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second)) + + // it is important that demux's context is cancelled, + // it has background goroutines attached + demuxListener := demux(ctx, l) + + serveErrors := make(chan error, 2) + go s.controlServerServe(ctx, demuxListener.control, serveErrors) + go s.dataServerServe(ctx, demuxListener.data, serveErrors) + select { + case serveErr := <-serveErrors: + s.logger.WithError(serveErr).Error("serve error") + s.logger.Debug("wait for other server to shut down") + cancel() + secondServeErr := <-serveErrors + s.logger.WithError(secondServeErr).Error("serve error") + case <-ctx.Done(): + s.logger.Debug("context cancelled, wait for control and data servers") + cancel() + for i := 0; i < 2; i++ { + <-serveErrors + } + s.logger.Debug("control and data server shut down, returning from Serve") + } +} diff --git a/rpc/transportmux/transportmux.go b/rpc/transportmux/transportmux.go new file mode 100644 index 0000000..cb6f7ca --- /dev/null +++ b/rpc/transportmux/transportmux.go @@ -0,0 +1,205 @@ +// Package transportmux wraps a transport.{Connecter,AuthenticatedListener} +// to distinguish different connection types based on a label +// sent from client to server on connection establishment. +// +// Labels are plain text and fixed length. 
+package transportmux + +import ( + "context" + "io" + "net" + "time" + "fmt" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/transport" +) + +type contextKey int + +const ( + contextKeyLog contextKey = 1 + iota +) + +type Logger = logger.Logger + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLog, log) +} + +func getLog(ctx context.Context) Logger { + if l, ok := ctx.Value(contextKeyLog).(Logger); ok { + return l + } + return logger.NewNullLogger() +} + +type acceptRes struct { + conn *transport.AuthConn + err error +} + +type demuxListener struct { + conns chan acceptRes +} + +func (l *demuxListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + res := <-l.conns + return res.conn, res.err +} + +type demuxAddr struct {} + +func (demuxAddr) Network() string { return "demux" } +func (demuxAddr) String() string { return "demux" } + +func (l *demuxListener) Addr() net.Addr { + return demuxAddr{} +} + +func (l *demuxListener) Close() error { return nil } // TODO + +// Exact length of a label in bytes (0-byte padded if it is shorter). +// This is a protocol constant, changing it breaks the wire protocol. 
+const LabelLen = 64 + +func padLabel(out []byte, label string) (error) { + if len(label) > LabelLen { + return fmt.Errorf("label %q exceeds max length (is %d, max %d)", label, len(label), LabelLen) + } + if len(out) != LabelLen { + panic(fmt.Sprintf("implementation error: %d", out)) + } + labelBytes := []byte(label) + copy(out[:], labelBytes) + return nil +} + +func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, labels []string, timeout time.Duration) (map[string]transport.AuthenticatedListener, error) { + + padded := make(map[[64]byte]*demuxListener, len(labels)) + ret := make(map[string]transport.AuthenticatedListener, len(labels)) + for _, label := range labels { + var labelPadded [LabelLen]byte + err := padLabel(labelPadded[:], label) + if err != nil { + return nil, err + } + if _, ok := padded[labelPadded]; ok { + return nil, fmt.Errorf("duplicate label %q", label) + } + dl := &demuxListener{make(chan acceptRes)} + padded[labelPadded] = dl + ret[label] = dl + } + + // invariant: padded contains same-length, non-duplicate labels + + go func() { + <-ctx.Done() + getLog(ctx).Debug("context cancelled, closing listener") + if err := rawListener.Close(); err != nil { + getLog(ctx).WithError(err).Error("error closing listener") + } + }() + + go func() { + for { + rawConn, err := rawListener.Accept(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + getLog(ctx).WithError(err).Error("accept error") + continue + } + closeConn := func() { + if err := rawConn.Close(); err != nil { + getLog(ctx).WithError(err).Error("cannot close conn") + } + } + + if err := rawConn.SetDeadline(time.Now().Add(timeout)); err != nil { + getLog(ctx).WithError(err).Error("SetDeadline failed") + closeConn() + continue + } + + var labelBuf [LabelLen]byte + if _, err := io.ReadFull(rawConn, labelBuf[:]); err != nil { + getLog(ctx).WithError(err).Error("error reading label") + closeConn() + continue + } + + demuxListener, ok := padded[labelBuf] + if !ok { + 
getLog(ctx).WithError(err). + WithField("client_label", fmt.Sprintf("%q", labelBuf)). + Error("unknown client label") + closeConn() + continue + } + + rawConn.SetDeadline(time.Time{}) + // blocking is intentional + demuxListener.conns <- acceptRes{conn: rawConn, err: nil} + } + }() + + return ret, nil +} + +type labeledConnecter struct { + label []byte + transport.Connecter +} + +func (c labeledConnecter) Connect(ctx context.Context) (transport.Wire, error) { + conn, err := c.Connecter.Connect(ctx) + if err != nil { + return nil, err + } + closeConn := func(why error) { + getLog(ctx).WithField("reason", why.Error()).Debug("closing connection") + if err := conn.Close(); err != nil { + getLog(ctx).WithError(err).Error("error closing connection after label write error") + } + } + + if dl, ok := ctx.Deadline(); ok { + defer conn.SetDeadline(time.Time{}) + if err := conn.SetDeadline(dl); err != nil { + closeConn(err) + return nil, err + } + } + n, err := conn.Write(c.label) + if err != nil { + closeConn(err) + return nil, err + } + if n != len(c.label) { + closeConn(fmt.Errorf("short label write")) + return nil, io.ErrShortWrite + } + return conn, nil +} + +func MuxConnecter(rawConnecter transport.Connecter, labels []string, timeout time.Duration) (map[string]transport.Connecter, error) { + ret := make(map[string]transport.Connecter, len(labels)) + for _, label := range labels { + var paddedLabel [LabelLen]byte + if err := padLabel(paddedLabel[:], label); err != nil { + return nil, err + } + lc := &labeledConnecter{paddedLabel[:], rawConnecter} + if _, ok := ret[label]; ok { + return nil, fmt.Errorf("duplicate label %q", label) + } + ret[label] = lc + } + return ret, nil +} + diff --git a/rpc/versionhandshake/versionhandshake.go b/rpc/versionhandshake/versionhandshake.go new file mode 100644 index 0000000..3864868 --- /dev/null +++ b/rpc/versionhandshake/versionhandshake.go @@ -0,0 +1,181 @@ +// Package versionhandshake wraps a 
transport.{Connecter,AuthenticatedListener} +// to add an exchange of protocol version information on connection establishment. +// +// The protocol version information (banner) is plain text, thus making it +// easy to diagnose issues with standard tools. +package versionhandshake + +import ( + "bytes" + "fmt" + "io" + "net" + "strings" + "time" + "unicode/utf8" +) + +type HandshakeMessage struct { + ProtocolVersion int + Extensions []string +} + +// A HandshakeError describes what went wrong during the handshake. +// It implements net.Error and is always temporary. +type HandshakeError struct { + msg string + // If not nil, the underlying IO error that caused the handshake to fail. + IOError error +} + +var _ net.Error = &HandshakeError{} + +func (e HandshakeError) Error() string { return e.msg } + +// Always true to enable usage in a net.Listener. +func (e HandshakeError) Temporary() bool { return true } + +// If the underlying IOError was net.Error.Timeout(), Timeout() returns that value. +// Otherwise false. +func (e HandshakeError) Timeout() bool { + if neterr, ok := e.IOError.(net.Error); ok { + return neterr.Timeout() + } + return false +} + +func hsErr(format string, args... interface{}) *HandshakeError { + return &HandshakeError{msg: fmt.Sprintf(format, args...)} +} + +func hsIOErr(err error, format string, args... interface{}) *HandshakeError { + return &HandshakeError{IOError: err, msg: fmt.Sprintf(format, args...)} +} + +// MaxProtocolVersion is the maximum allowed protocol version. +// This is a protocol constant, changing it may break the wire format. +const MaxProtocolVersion = 9999 + +// Only returns *HandshakeError as error. 
+func (m *HandshakeMessage) Encode() ([]byte, error) { + if m.ProtocolVersion <= 0 || m.ProtocolVersion > MaxProtocolVersion { + return nil, hsErr(fmt.Sprintf("protocol version must be in [1, %d]", MaxProtocolVersion)) + } + if len(m.Extensions) >= MaxProtocolVersion { + return nil, hsErr(fmt.Sprintf("protocol only supports [0, %d] extensions", MaxProtocolVersion)) + } + // EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions + var extensions strings.Builder + for i, ext := range m.Extensions { + if strings.ContainsAny(ext, "\n") { + return nil, hsErr("Extension #%d contains forbidden newline character", i) + } + if !utf8.ValidString(ext) { + return nil, hsErr("Extension #%d is not valid UTF-8", i) + } + extensions.WriteString(ext) + extensions.WriteString("\n") + } + withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s", + m.ProtocolVersion, len(m.Extensions), extensions.String()) + withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen) + return []byte(withLen), nil +} + +func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error { + var lenAndSpace [11]byte + if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil { + return hsIOErr(err, "error reading protocol banner length: %s", err) + } + if !utf8.Valid(lenAndSpace[:]) { + return hsErr("invalid start of handshake message: not valid UTF-8") + } + var followLen int + n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen) + if n != 1 || err != nil { + return hsErr("could not parse handshake message length") + } + if followLen > maxLen { + return hsErr("handshake message length exceeds max length (%d vs %d)", + followLen, maxLen) + } + + var buf bytes.Buffer + _, err = io.Copy(&buf, io.LimitReader(r, int64(followLen))) + if err != nil { + return hsIOErr(err, "error reading protocol banner body: %s", err) + } + + var ( + protoVersion, extensionCount int + ) + n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION 
PROTOVERSION=%04d EXTENSIONS=%4d\n", + &protoVersion, &extensionCount) + if n != 2 || err != nil { + return hsErr("could not parse handshake message: %s", err) + } + if protoVersion < 1 { + return hsErr("invalid protocol version %q", protoVersion) + } + m.ProtocolVersion = protoVersion + + if extensionCount < 0 { + return hsErr("invalid extension count %q", extensionCount) + } + if extensionCount == 0 { + if buf.Len() != 0 { + return hsErr("unexpected data trailing after header") + } + m.Extensions = nil + return nil + } + s := buf.String() + if strings.Count(s, "\n") != extensionCount { + return hsErr("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount) + } + exts := strings.Split(s, "\n") + if exts[len(exts)-1] != "" { + return hsErr("unexpected data trailing after last extension newline") + } + m.Extensions = exts[0:len(exts)-1] + + return nil +} + +func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error { + // current protocol version is hardcoded here + return DoHandshakeVersion(conn, deadline, 1) +} + +const HandshakeMessageMaxLen = 16 * 4096 + +func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error { + ours := HandshakeMessage{ + ProtocolVersion: version, + Extensions: nil, + } + hsb, err := ours.Encode() + if err != nil { + return hsErr("could not encode protocol banner: %s", err) + } + + defer conn.SetDeadline(time.Time{}) + conn.SetDeadline(deadline) + _, err = io.Copy(conn, bytes.NewBuffer(hsb)) + if err != nil { + return hsErr("could not send protocol banner: %s", err) + } + + theirs := HandshakeMessage{} + if err := theirs.DecodeReader(conn, HandshakeMessageMaxLen); err != nil { + return hsErr("could not decode protocol banner: %s", err) + } + + if theirs.ProtocolVersion != ours.ProtocolVersion { + return hsErr("protocol versions do not match: ours is %d, theirs is %d", + ours.ProtocolVersion, theirs.ProtocolVersion) + } + // ignore extensions, we don't use them + + return 
nil
+}
diff --git a/daemon/transport/handshake_test.go b/rpc/versionhandshake/versionhandshake_test.go
similarity index 99%
rename from daemon/transport/handshake_test.go
rename to rpc/versionhandshake/versionhandshake_test.go
index d1c72b4..dd27c9d 100644
--- a/daemon/transport/handshake_test.go
+++ b/rpc/versionhandshake/versionhandshake_test.go
@@ -1,4 +1,4 @@
-package transport
+package versionhandshake
 
 import (
 	"bytes"
diff --git a/rpc/versionhandshake/versionhandshake_transport_wrappers.go b/rpc/versionhandshake/versionhandshake_transport_wrappers.go
new file mode 100644
index 0000000..660215e
--- /dev/null
+++ b/rpc/versionhandshake/versionhandshake_transport_wrappers.go
@@ -0,0 +1,66 @@
+package versionhandshake
+
+import (
+	"context"
+	"net"
+	"time"
+	"github.com/zrepl/zrepl/transport"
+)
+
+type HandshakeConnecter struct {
+	connecter transport.Connecter
+	timeout   time.Duration
+}
+
+func (c HandshakeConnecter) Connect(ctx context.Context) (transport.Wire, error) {
+	conn, err := c.connecter.Connect(ctx)
+	if err != nil {
+		return nil, err
+	}
+	dl, ok := ctx.Deadline()
+	if !ok {
+		dl = time.Now().Add(c.timeout)
+	}
+	if err := DoHandshakeCurrentVersion(conn, dl); err != nil {
+		conn.Close()
+		return nil, err
+	}
+	return conn, nil
+}
+
+func Connecter(connecter transport.Connecter, timeout time.Duration) HandshakeConnecter {
+	return HandshakeConnecter{
+		connecter: connecter,
+		timeout:   timeout,
+	}
+}
+
+// wrapper type that performs a protocol version handshake before returning the connection
+type HandshakeListener struct {
+	l       transport.AuthenticatedListener
+	timeout time.Duration
+}
+
+func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() }
+
+func (l HandshakeListener) Close() error { return l.l.Close() }
+
+func (l HandshakeListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
+	conn, err := l.l.Accept(ctx)
+	if err != nil {
+		return nil, err
+	}
+	dl, ok := ctx.Deadline()
+	if !ok {
+		dl = 
time.Now().Add(l.timeout) // shadowing
+	}
+	if err := DoHandshakeCurrentVersion(conn, dl); err != nil {
+		conn.Close()
+		return nil, err
+	}
+	return conn, nil
+}
+
+func Listener(l transport.AuthenticatedListener, timeout time.Duration) transport.AuthenticatedListener {
+	return HandshakeListener{l, timeout}
+}
diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go
index ffe6094..07a6669 100644
--- a/tlsconf/tlsconf.go
+++ b/tlsconf/tlsconf.go
@@ -25,12 +25,13 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
 }
 
 type ClientAuthListener struct {
-	l                net.Listener
+	l                *net.TCPListener
+	c                *tls.Config
 	handshakeTimeout time.Duration
 }
 
 func NewClientAuthListener(
-	l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
+	l *net.TCPListener, ca *x509.CertPool, serverCert tls.Certificate,
 	handshakeTimeout time.Duration) *ClientAuthListener {
 
 	if ca == nil {
@@ -40,30 +41,35 @@ func NewClientAuthListener(
 		panic(serverCert)
 	}
 
-	tlsConf := tls.Config{
+	tlsConf := &tls.Config{
 		Certificates:             []tls.Certificate{serverCert},
 		ClientCAs:                ca,
 		ClientAuth:               tls.RequireAndVerifyClientCert,
 		PreferServerCipherSuites: true,
 		KeyLogWriter:             keylogFromEnv(),
 	}
-	l = tls.NewListener(l, &tlsConf)
 	return &ClientAuthListener{
 		l,
+		tlsConf,
 		handshakeTimeout,
 	}
 }
 
-func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
-	c, err = l.l.Accept()
+// Accept() accepts a connection from the *net.TCPListener passed to the constructor
+// and sets up the TLS connection, including handshake and peer CommonName validation
+// within the specified handshakeTimeout.
+//
+// It returns both the raw TCP connection (tcpConn) and the TLS connection (tlsConn) on top of it.
+// Access to the raw tcpConn might be necessary if CloseWrite semantics are desired:
+// tlsConn.CloseWrite does NOT call tcpConn.CloseWrite, hence we provide access to tcpConn to
+// allow the caller to do this by themselves. 
+func (l *ClientAuthListener) Accept() (tcpConn *net.TCPConn, tlsConn *tls.Conn, clientCN string, err error) { + tcpConn, err = l.l.AcceptTCP() if err != nil { - return nil, "", err - } - tlsConn, ok := c.(*tls.Conn) - if !ok { - return c, "", err + return nil, nil, "", err } + tlsConn = tls.Server(tcpConn, l.c) var ( cn string peerCerts []*x509.Certificate @@ -82,10 +88,11 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { goto CloseAndErr } cn = peerCerts[0].Subject.CommonName - return c, cn, nil + return tcpConn, tlsConn, cn, nil CloseAndErr: - c.Close() - return nil, "", err + // unlike CloseWrite, Close on *tls.Conn actually closes the underlying connection + tlsConn.Close() // TODO log error + return nil, nil, "", err } func (l *ClientAuthListener) Addr() net.Addr { diff --git a/transport/fromconfig/transport_fromconfig.go b/transport/fromconfig/transport_fromconfig.go new file mode 100644 index 0000000..0aa1426 --- /dev/null +++ b/transport/fromconfig/transport_fromconfig.go @@ -0,0 +1,58 @@ +// Package fromconfig instantiates transports based on zrepl config structures +// (see package config). 
+package fromconfig + +import ( + "fmt" + "github.com/pkg/errors" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/transport/local" + "github.com/zrepl/zrepl/transport/ssh" + "github.com/zrepl/zrepl/transport/tcp" + "github.com/zrepl/zrepl/transport/tls" +) + +func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory,error) { + + var ( + l transport.AuthenticatedListenerFactory + err error + ) + switch v := in.Ret.(type) { + case *config.TCPServe: + l, err = tcp.TCPListenerFactoryFromConfig(g, v) + case *config.TLSServe: + l, err = tls.TLSListenerFactoryFromConfig(g, v) + case *config.StdinserverServer: + l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v) + case *config.LocalServe: + l, err = local.LocalListenerFactoryFromConfig(g, v) + default: + return nil, errors.Errorf("internal error: unknown serve type %T", v) + } + + return l, err +} + + +func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) { + var ( + connecter transport.Connecter + err error + ) + switch v := in.Ret.(type) { + case *config.SSHStdinserverConnect: + connecter, err = ssh.SSHStdinserverConnecterFromConfig(v) + case *config.TCPConnect: + connecter, err = tcp.TCPConnecterFromConfig(v) + case *config.TLSConnect: + connecter, err = tls.TLSConnecterFromConfig(v) + case *config.LocalConnect: + connecter, err = local.LocalConnecterFromConfig(v) + default: + panic(fmt.Sprintf("implementation error: unknown connecter type %T", v)) + } + + return connecter, err +} diff --git a/daemon/transport/connecter/connect_local.go b/transport/local/connect_local.go similarity index 72% rename from daemon/transport/connecter/connect_local.go rename to transport/local/connect_local.go index 45c3d68..ba390b8 100644 --- a/daemon/transport/connecter/connect_local.go +++ b/transport/local/connect_local.go @@ -1,11 +1,10 @@ -package connecter +package local import ( 
"context" "fmt" "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/transport/serve" - "net" + "github.com/zrepl/zrepl/transport" ) type LocalConnecter struct { @@ -23,8 +22,8 @@ func LocalConnecterFromConfig(in *config.LocalConnect) (*LocalConnecter, error) return &LocalConnecter{listenerName: in.ListenerName, clientIdentity: in.ClientIdentity}, nil } -func (c *LocalConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { - l := serve.GetLocalListener(c.listenerName) +func (c *LocalConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { + l := GetLocalListener(c.listenerName) return l.Connect(dialCtx, c.clientIdentity) } diff --git a/daemon/transport/serve/serve_local.go b/transport/local/serve_local.go similarity index 74% rename from daemon/transport/serve/serve_local.go rename to transport/local/serve_local.go index f71ba70..f7e42aa 100644 --- a/daemon/transport/serve/serve_local.go +++ b/transport/local/serve_local.go @@ -1,4 +1,4 @@ -package serve +package local import ( "context" @@ -7,6 +7,7 @@ import ( "github.com/zrepl/zrepl/util/socketpair" "net" "sync" + "github.com/zrepl/zrepl/transport" ) var localListeners struct { @@ -39,7 +40,7 @@ type connectRequest struct { } type connectResult struct { - conn net.Conn + conn transport.Wire err error } @@ -54,7 +55,7 @@ func newLocalListener() *LocalListener { } // Connect to the LocalListener from a client with identity clientIdentity -func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn net.Conn, err error) { +func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn transport.Wire, err error) { // place request req := connectRequest{ @@ -89,21 +90,14 @@ func (a localAddr) String() string { return a.S } func (l *LocalListener) Addr() (net.Addr) { return localAddr{""} } -type localConn struct { - net.Conn - clientIdentity string -} - -func (l localConn) ClientIdentity() string { return l.clientIdentity } - 
-func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { +func (l *LocalListener) Accept(ctx context.Context) (*transport.AuthConn, error) { respondToRequest := func(req connectRequest, res connectResult) (err error) { - getLogger(ctx). + transport.GetLogger(ctx). WithField("res.conn", res.conn).WithField("res.err", res.err). Debug("responding to client request") defer func() { errv := recover() - getLogger(ctx).WithField("recover_err", errv). + transport.GetLogger(ctx).WithField("recover_err", errv). Debug("panic on send to client callback, likely a legitimate client-side timeout") }() select { @@ -116,7 +110,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { return err } - getLogger(ctx).Debug("waiting for local client connect requests") + transport.GetLogger(ctx).Debug("waiting for local client connect requests") var req connectRequest select { case req = <-l.connects: @@ -124,7 +118,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { return nil, ctx.Err() } - getLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request") + transport.GetLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request") if req.clientIdentity == "" { res := connectResult{nil, fmt.Errorf("client identity must not be empty")} if err := respondToRequest(req, res); err != nil { @@ -133,31 +127,31 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) { return nil, fmt.Errorf("client connected with empty client identity") } - getLogger(ctx).Debug("creating socketpair") + transport.GetLogger(ctx).Debug("creating socketpair") left, right, err := socketpair.SocketPair() if err != nil { res := connectResult{nil, fmt.Errorf("server error: %s", err)} if respErr := respondToRequest(req, res); respErr != nil { // returning the socketpair error properly is more important than the error sent to the client - 
getLogger(ctx).WithError(respErr).Error("error responding to client") + transport.GetLogger(ctx).WithError(respErr).Error("error responding to client") } return nil, err } - getLogger(ctx).Debug("responding with left side of socketpair") + transport.GetLogger(ctx).Debug("responding with left side of socketpair") res := connectResult{left, nil} if err := respondToRequest(req, res); err != nil { - getLogger(ctx).WithError(err).Error("error responding to client") + transport.GetLogger(ctx).WithError(err).Error("error responding to client") if err := left.Close(); err != nil { - getLogger(ctx).WithError(err).Error("cannot close left side of socketpair") + transport.GetLogger(ctx).WithError(err).Error("cannot close left side of socketpair") } if err := right.Close(); err != nil { - getLogger(ctx).WithError(err).Error("cannot close right side of socketpair") + transport.GetLogger(ctx).WithError(err).Error("cannot close right side of socketpair") } return nil, err } - return localConn{right, req.clientIdentity}, nil + return transport.NewAuthConn(right, req.clientIdentity), nil } func (l *LocalListener) Close() error { @@ -169,19 +163,13 @@ func (l *LocalListener) Close() error { return nil } -type LocalListenerFactory struct { - listenerName string -} - -func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (f *LocalListenerFactory, err error) { +func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (transport.AuthenticatedListenerFactory,error) { if in.ListenerName == "" { return nil, fmt.Errorf("ListenerName must not be empty") } - return &LocalListenerFactory{listenerName: in.ListenerName}, nil + listenerName := in.ListenerName + lf := func() (transport.AuthenticatedListener,error) { + return GetLocalListener(listenerName), nil + } + return lf, nil } - - -func (lf *LocalListenerFactory) Listen() (AuthenticatedListener, error) { - return GetLocalListener(lf.listenerName), nil -} - diff --git 
a/daemon/transport/connecter/connect_ssh.go b/transport/ssh/connect_ssh.go similarity index 72% rename from daemon/transport/connecter/connect_ssh.go rename to transport/ssh/connect_ssh.go index 7efeec5..d669b88 100644 --- a/daemon/transport/connecter/connect_ssh.go +++ b/transport/ssh/connect_ssh.go @@ -1,13 +1,12 @@ -package connecter +package ssh import ( "context" "github.com/jinzhu/copier" "github.com/pkg/errors" "github.com/problame/go-netssh" - "github.com/problame/go-streamrpc" "github.com/zrepl/zrepl/config" - "net" + "github.com/zrepl/zrepl/transport" "time" ) @@ -22,8 +21,6 @@ type SSHStdinserverConnecter struct { dialTimeout time.Duration } -var _ streamrpc.Connecter = &SSHStdinserverConnecter{} - func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSHStdinserverConnecter, err error) { c = &SSHStdinserverConnecter{ @@ -39,15 +36,7 @@ func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSH } -type netsshConnToConn struct{ *netssh.SSHConn } - -var _ net.Conn = netsshConnToConn{} - -func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil } -func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil } -func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil } - -func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) { +func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { var endpoint netssh.Endpoint if err := copier.Copy(&endpoint, c); err != nil { @@ -62,5 +51,5 @@ func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, er } return nil, err } - return netsshConnToConn{nconn}, nil + return nconn, nil } diff --git a/daemon/transport/serve/serve_stdinserver.go b/transport/ssh/serve_stdinserver.go similarity index 57% rename from daemon/transport/serve/serve_stdinserver.go rename to transport/ssh/serve_stdinserver.go index f02bf20..39bfba8 100644 --- 
a/daemon/transport/serve/serve_stdinserver.go +++ b/transport/ssh/serve_stdinserver.go @@ -1,50 +1,38 @@ -package serve +package ssh import ( "github.com/problame/go-netssh" "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/nethelpers" - "io" + "github.com/zrepl/zrepl/transport" + "fmt" "net" "path" - "time" "context" "github.com/pkg/errors" "sync/atomic" ) -type StdinserverListenerFactory struct { - ClientIdentities []string - Sockdir string -} - -func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) { +func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory,error) { for _, ci := range in.ClientIdentities { - if err := ValidateClientIdentity(ci); err != nil { + if err := transport.ValidateClientIdentity(ci); err != nil { return nil, errors.Wrapf(err, "invalid client identity %q", ci) } } - f = &multiStdinserverListenerFactory{ - ClientIdentities: in.ClientIdentities, - Sockdir: g.Serve.StdinServer.SockDir, + clientIdentities := in.ClientIdentities + sockdir := g.Serve.StdinServer.SockDir + + lf := func() (transport.AuthenticatedListener,error) { + return multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities) } - return -} - -type multiStdinserverListenerFactory struct { - ClientIdentities []string - Sockdir string -} - -func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) { - return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities) + return lf, nil } type multiStdinserverAcceptRes struct { - conn AuthenticatedConn + conn *transport.AuthConn err error } @@ -78,7 +66,7 @@ func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) return &MultiStdinserverListener{listeners: listeners}, nil } -func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){ +func (m 
*MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error){ if m.accepts == nil { m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners)) @@ -97,8 +85,22 @@ func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedCon } -func (m *MultiStdinserverListener) Addr() (net.Addr) { - return netsshAddr{} +type multiListenerAddr struct { + clients []string +} + +func (multiListenerAddr) Network() string { return "netssh" } + +func (l multiListenerAddr) String() string { + return fmt.Sprintf("netssh:clients=%v", l.clients) +} + +func (m *MultiStdinserverListener) Addr() net.Addr { + cis := make([]string, len(m.listeners)) + for i := range cis { + cis[i] = m.listeners[i].clientIdentity + } + return multiListenerAddr{cis} } func (m *MultiStdinserverListener) Close() error { @@ -118,41 +120,28 @@ type stdinserverListener struct { clientIdentity string } -func (l stdinserverListener) Addr() net.Addr { - return netsshAddr{} +type listenerAddr struct { + clientIdentity string } -func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) { +func (listenerAddr) Network() string { return "netssh" } + +func (a listenerAddr) String() string { + return fmt.Sprintf("netssh:client=%q", a.clientIdentity) +} + +func (l stdinserverListener) Addr() net.Addr { + return listenerAddr{l.clientIdentity} +} + +func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) { c, err := l.l.Accept() if err != nil { return nil, err } - return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil + return transport.NewAuthConn(c, l.clientIdentity), nil } func (l stdinserverListener) Close() (err error) { return l.l.Close() } - -type netsshAddr struct{} - -func (netsshAddr) Network() string { return "netssh" } -func (netsshAddr) String() string { return "???" 
} - -type netsshConnToNetConnAdatper struct { - io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn - clientIdentity string -} - -func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity } - -func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} } - -func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} } - -// FIXME log warning once! -func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil } - -func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil } - -func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil } diff --git a/daemon/transport/connecter/connect_tcp.go b/transport/tcp/connect_tcp.go similarity index 54% rename from daemon/transport/connecter/connect_tcp.go rename to transport/tcp/connect_tcp.go index 3d8b77e..1176512 100644 --- a/daemon/transport/connecter/connect_tcp.go +++ b/transport/tcp/connect_tcp.go @@ -1,9 +1,11 @@ -package connecter +package tcp import ( "context" - "github.com/zrepl/zrepl/config" "net" + + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" ) type TCPConnecter struct { @@ -19,6 +21,10 @@ func TCPConnecterFromConfig(in *config.TCPConnect) (*TCPConnecter, error) { return &TCPConnecter{in.Address, dialer}, nil } -func (c *TCPConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { - return c.dialer.DialContext(dialCtx, "tcp", c.Address) +func (c *TCPConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { + conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address) + if err != nil { + return nil, err + } + return conn.(*net.TCPConn), nil } diff --git a/daemon/transport/serve/serve_tcp.go b/transport/tcp/serve_tcp.go similarity index 68% rename from daemon/transport/serve/serve_tcp.go rename to transport/tcp/serve_tcp.go index 957d3b9..a6b8107 100644 --- a/daemon/transport/serve/serve_tcp.go +++ 
b/transport/tcp/serve_tcp.go @@ -1,17 +1,13 @@ -package serve +package tcp import ( "github.com/zrepl/zrepl/config" "net" "github.com/pkg/errors" "context" + "github.com/zrepl/zrepl/transport" ) -type TCPListenerFactory struct { - address *net.TCPAddr - clientMap *ipMap -} - type ipMapEntry struct { ip net.IP ident string @@ -28,7 +24,7 @@ func ipMapFromConfig(clients map[string]string) (*ipMap, error) { if clientIP == nil { return nil, errors.Errorf("cannot parse client IP %q", clientIPString) } - if err := ValidateClientIdentity(clientIdent); err != nil { + if err := transport.ValidateClientIdentity(clientIdent); err != nil { return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString) } entries = append(entries, ipMapEntry{clientIP, clientIdent}) @@ -45,7 +41,7 @@ func (m *ipMap) Get(ip net.IP) (string, error) { return "", errors.Errorf("no identity mapping for client IP %s", ip) } -func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) { +func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (transport.AuthenticatedListenerFactory, error) { addr, err := net.ResolveTCPAddr("tcp", in.Listen) if err != nil { return nil, errors.Wrap(err, "cannot parse listen address") @@ -54,38 +50,33 @@ func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPLi if err != nil { return nil, errors.Wrap(err, "cannot parse client IP map") } - lf := &TCPListenerFactory{ - address: addr, - clientMap: clientMap, + lf := func() (transport.AuthenticatedListener, error) { + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, err + } + return &TCPAuthListener{l, clientMap}, nil } return lf, nil } -func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) { - l, err := net.ListenTCP("tcp", f.address) - if err != nil { - return nil, err - } - return &TCPAuthListener{l, f.clientMap}, nil -} - type TCPAuthListener struct { *net.TCPListener clientMap *ipMap } 
-func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { - nc, err := f.TCPListener.Accept() +func (f *TCPAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + nc, err := f.TCPListener.AcceptTCP() if err != nil { return nil, err } clientIP := nc.RemoteAddr().(*net.TCPAddr).IP clientIdent, err := f.clientMap.Get(clientIP) if err != nil { - getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map") + transport.GetLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map") nc.Close() return nil, err } - return authConn{nc, clientIdent}, nil + return transport.NewAuthConn(nc, clientIdent), nil } diff --git a/daemon/transport/connecter/connect_tls.go b/transport/tls/connect_tls.go similarity index 73% rename from daemon/transport/connecter/connect_tls.go rename to transport/tls/connect_tls.go index a60cb45..ea578d4 100644 --- a/daemon/transport/connecter/connect_tls.go +++ b/transport/tls/connect_tls.go @@ -1,12 +1,14 @@ -package connecter +package tls import ( "context" "crypto/tls" + "net" + "github.com/pkg/errors" "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/tlsconf" - "net" + "github.com/zrepl/zrepl/transport" ) type TLSConnecter struct { @@ -38,10 +40,12 @@ func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) { return &TLSConnecter{in.Address, dialer, tlsConfig}, nil } -func (c *TLSConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) { - conn, err = c.dialer.DialContext(dialCtx, "tcp", c.Address) +func (c *TLSConnecter) Connect(dialCtx context.Context) (transport.Wire, error) { + conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address) if err != nil { return nil, err } - return tls.Client(conn, c.tlsConfig), nil + tcpConn := conn.(*net.TCPConn) + tlsConn := tls.Client(conn, c.tlsConfig) + return newWireAdaptor(tlsConn, tcpConn), nil } diff --git a/transport/tls/serve_tls.go b/transport/tls/serve_tls.go new file mode 100644 index 
0000000..21aafe4 --- /dev/null +++ b/transport/tls/serve_tls.go @@ -0,0 +1,89 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "github.com/pkg/errors" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/tlsconf" + "net" + "time" + "context" +) + +type TLSListenerFactory struct { + address string + clientCA *x509.CertPool + serverCert tls.Certificate + handshakeTimeout time.Duration + clientCNs map[string]struct{} +} + +func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory,error) { + + address := in.Listen + handshakeTimeout := in.HandshakeTimeout + + if in.Ca == "" || in.Cert == "" || in.Key == "" { + return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") + } + + clientCA, err := tlsconf.ParseCAFile(in.Ca) + if err != nil { + return nil, errors.Wrap(err, "cannot parse ca file") + } + + serverCert, err := tls.LoadX509KeyPair(in.Cert, in.Key) + if err != nil { + return nil, errors.Wrap(err, "cannot parse cer/key pair") + } + + clientCNs := make(map[string]struct{}, len(in.ClientCNs)) + for i, cn := range in.ClientCNs { + if err := transport.ValidateClientIdentity(cn); err != nil { + return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn) + } + // dupes are ok fr now + clientCNs[cn] = struct{}{} + } + + lf := func() (transport.AuthenticatedListener, error) { + l, err := net.Listen("tcp", address) + if err != nil { + return nil, err + } + tcpL := l.(*net.TCPListener) + tl := tlsconf.NewClientAuthListener(tcpL, clientCA, serverCert, handshakeTimeout) + return &tlsAuthListener{tl, clientCNs}, nil + } + + return lf, nil +} + +type tlsAuthListener struct { + *tlsconf.ClientAuthListener + clientCNs map[string]struct{} +} + +func (l tlsAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + tcpConn, tlsConn, cn, err := l.ClientAuthListener.Accept() + if err != nil { + return nil, err + } + if _, ok := 
l.clientCNs[cn]; !ok { + if dl, ok := ctx.Deadline(); ok { + defer tlsConn.SetDeadline(time.Time{}) + tlsConn.SetDeadline(dl) + } + if err := tlsConn.Close(); err != nil { + transport.GetLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name") + } + return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, tlsConn.RemoteAddr()) + } + adaptor := newWireAdaptor(tlsConn, tcpConn) + return transport.NewAuthConn(adaptor, cn), nil +} + + diff --git a/transport/tls/tls_wire_adaptor.go b/transport/tls/tls_wire_adaptor.go new file mode 100644 index 0000000..2f03cdd --- /dev/null +++ b/transport/tls/tls_wire_adaptor.go @@ -0,0 +1,47 @@ +package tls + +import ( + "crypto/tls" + "fmt" + "net" + "os" +) + +// adapts a tls.Conn and its underlying net.TCPConn into a valid transport.Wire +type transportWireAdaptor struct { + *tls.Conn + tcpConn *net.TCPConn +} + +func newWireAdaptor(tlsConn *tls.Conn, tcpConn *net.TCPConn) transportWireAdaptor { + return transportWireAdaptor{tlsConn, tcpConn} +} + +// CloseWrite implements transport.Wire.CloseWrite which is different from *tls.Conn.CloseWrite: +// the former requires that the other side observes io.EOF, but *tls.Conn.CloseWrite does not +// close the underlying connection so no io.EOF would be observed. +func (w transportWireAdaptor) CloseWrite() error { + if err := w.Conn.CloseWrite(); err != nil { + // TODO log error + fmt.Fprintf(os.Stderr, "transport/tls.CloseWrite() error: %s\n", err) + } + return w.tcpConn.CloseWrite() +} + +// Close implements transport.Wire.Close which is different from a *tls.Conn.Close: +// At the time of writing (Go 1.11), closing tls.Conn closes the TCP connection immediately, +// which results in io.ErrUnexpectedEOF on the other side. 
+// We assume that w.Conn has a deadline set for the close, so the CloseWrite will time out if it blocks, +// falling through to the actual Close() +func (w transportWireAdaptor) Close() error { + // var buf [1<<15]byte + // w.Conn.Write(buf[:]) + // CloseWrite will send a TLS alert record down the line which + // in the Go implementation acts like a flush...? + // if err := w.Conn.CloseWrite(); err != nil { + // // TODO log error + // fmt.Fprintf(os.Stderr, "transport/tls.Close() close write error: %s\n", err) + // } + // time.Sleep(1 * time.Second) + return w.Conn.Close() +} diff --git a/transport/transport.go b/transport/transport.go new file mode 100644 index 0000000..e520928 --- /dev/null +++ b/transport/transport.go @@ -0,0 +1,84 @@ +// Package transport defines a common interface for +// network connections that have an associated client identity. +package transport + +import ( + "context" + "errors" + "net" + "syscall" + + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" + "github.com/zrepl/zrepl/zfs" +) + +type AuthConn struct { + Wire + clientIdentity string +} + +var _ timeoutconn.SyscallConner = AuthConn{} + +func (a AuthConn) SyscallConn() (rawConn syscall.RawConn, err error) { + scc, ok := a.Wire.(timeoutconn.SyscallConner) + if !ok { + return nil, timeoutconn.SyscallConnNotSupported + } + return scc.SyscallConn() +} + +func NewAuthConn(conn Wire, clientIdentity string) *AuthConn { + return &AuthConn{conn, clientIdentity} +} + +func (c *AuthConn) ClientIdentity() string { + if err := ValidateClientIdentity(c.clientIdentity); err != nil { + panic(err) + } + return c.clientIdentity +} + +// like net.Listener, but with an AuthenticatedConn instead of net.Conn +type AuthenticatedListener interface { + Addr() net.Addr + Accept(ctx context.Context) (*AuthConn, error) + Close() error +} + +type AuthenticatedListenerFactory func() (AuthenticatedListener, error) + +type Wire = timeoutconn.Wire + +type Connecter interface { + 
Connect(ctx context.Context) (Wire, error) +} + +// A client identity must be a single component in a ZFS filesystem path +func ValidateClientIdentity(in string) (err error) { + path, err := zfs.NewDatasetPath(in) + if err != nil { + return err + } + if path.Length() != 1 { + return errors.New("client identity must be a single path comonent (not empty, no '/')") + } + return nil +} + +type contextKey int + +const contextKeyLog contextKey = 0 + +type Logger = logger.Logger + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLog, log) +} + +func GetLogger(ctx context.Context) Logger { + if log, ok := ctx.Value(contextKeyLog).(Logger); ok { + return log + } + return logger.NewNullLogger() +} diff --git a/util/bytecounter/bytecounter_streamcopier.go b/util/bytecounter/bytecounter_streamcopier.go new file mode 100644 index 0000000..a268a91 --- /dev/null +++ b/util/bytecounter/bytecounter_streamcopier.go @@ -0,0 +1,71 @@ +package bytecounter + +import ( + "io" + "sync/atomic" + + "github.com/zrepl/zrepl/zfs" +) + +// StreamCopier wraps a zfs.StreamCopier, reimplemening +// its interface and counting the bytes written to during copying. +type StreamCopier interface { + zfs.StreamCopier + Count() int64 +} + +// NewStreamCopier wraps sc into a StreamCopier. +// If sc is io.Reader, it is guaranteed that the returned StreamCopier +// implements that interface, too. 
+func NewStreamCopier(sc zfs.StreamCopier) StreamCopier { + bsc := &streamCopier{sc, 0} + if scr, ok := sc.(io.Reader); ok { + return streamCopierAndReader{bsc, scr} + } else { + return bsc + } +} + +type streamCopier struct { + sc zfs.StreamCopier + count int64 +} + +// proxy writer used by streamCopier +type streamCopierWriter struct { + parent *streamCopier + w io.Writer +} + +func (w streamCopierWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + atomic.AddInt64(&w.parent.count, int64(n)) + return +} + +func (s *streamCopier) Count() int64 { + return atomic.LoadInt64(&s.count) +} + +var _ zfs.StreamCopier = &streamCopier{} + +func (s streamCopier) Close() error { + return s.sc.Close() +} + +func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError { + ww := streamCopierWriter{s, w} + return s.sc.WriteStreamTo(ww) +} + +// a streamCopier whose underlying sc is an io.Reader +type streamCopierAndReader struct { + *streamCopier + asReader io.Reader +} + +func (scr streamCopierAndReader) Read(p []byte) (int, error) { + n, err := scr.asReader.Read(p) + atomic.AddInt64(&scr.streamCopier.count, int64(n)) + return n, err +} diff --git a/util/bytecounter/bytecounter_streamcopier_test.go b/util/bytecounter/bytecounter_streamcopier_test.go new file mode 100644 index 0000000..29611e0 --- /dev/null +++ b/util/bytecounter/bytecounter_streamcopier_test.go @@ -0,0 +1,38 @@ +package bytecounter + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zrepl/zrepl/zfs" +) + +type mockStreamCopierAndReader struct { + zfs.StreamCopier // to satisfy interface + reads int +} + +func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) { + r.reads++ + return len(p), nil +} + +var _ io.Reader = &mockStreamCopierAndReader{} + +func TestNewStreamCopierReexportsReader(t *testing.T) { + mock := &mockStreamCopierAndReader{} + x := NewStreamCopier(mock) + + r, ok := x.(io.Reader) + if !ok { + t.Fatalf("%T does not implement 
io.Reader, hence reader cannout have been wrapped", x) + } + + var buf [23]byte + n, err := r.Read(buf[:]) + assert.True(t, mock.reads == 1) + assert.True(t, n == len(buf)) + assert.NoError(t, err) + assert.True(t, x.Count() == 23) +} diff --git a/zfs/diff.go b/zfs/diff.go index 2b37f6d..52eb84f 100644 --- a/zfs/diff.go +++ b/zfs/diff.go @@ -274,7 +274,7 @@ func ZFSCreatePlaceholderFilesystem(p *DatasetPath) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } diff --git a/zfs/mapping.go b/zfs/mapping.go index 56a85b3..43bb5bb 100644 --- a/zfs/mapping.go +++ b/zfs/mapping.go @@ -9,8 +9,8 @@ type DatasetFilter interface { Filter(p *DatasetPath) (pass bool, err error) } -func ZFSListMapping(filter DatasetFilter) (datasets []*DatasetPath, err error) { - res, err := ZFSListMappingProperties(filter, nil) +func ZFSListMapping(ctx context.Context, filter DatasetFilter) (datasets []*DatasetPath, err error) { + res, err := ZFSListMappingProperties(ctx, filter, nil) if err != nil { return nil, err } @@ -28,7 +28,7 @@ type ZFSListMappingPropertiesResult struct { } // properties must not contain 'name' -func ZFSListMappingProperties(filter DatasetFilter, properties []string) (datasets []ZFSListMappingPropertiesResult, err error) { +func ZFSListMappingProperties(ctx context.Context, filter DatasetFilter, properties []string) (datasets []ZFSListMappingPropertiesResult, err error) { if filter == nil { panic("filter must not be nil") @@ -44,7 +44,7 @@ func ZFSListMappingProperties(filter DatasetFilter, properties []string) (datase copy(newProps[1:], properties) properties = newProps - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() rchan := make(chan ZFSListResult) diff --git a/zfs/zfs.go b/zfs/zfs.go index 56c9c77..8f08a66 100644 --- a/zfs/zfs.go +++ b/zfs/zfs.go @@ -7,15 +7,22 @@ import ( "errors" "fmt" "io" + "os" "os/exec" "strings" + "sync" 
+ "time" "context" - "github.com/problame/go-rwccmd" "github.com/prometheus/client_golang/prometheus" - "github.com/zrepl/zrepl/util" "regexp" "strconv" + "github.com/zrepl/zrepl/util/envconst" +) + +var ( + ZFSSendPipeCapacityHint = int(envconst.Int64("ZFS_SEND_PIPE_CAPACITY_HINT", 1<<25)) + ZFSRecvPipeCapacityHint = int(envconst.Int64("ZFS_RECV_PIPE_CAPACITY_HINT", 1<<25)) ) type DatasetPath struct { @@ -141,7 +148,7 @@ type ZFSError struct { WaitErr error } -func (e ZFSError) Error() string { +func (e *ZFSError) Error() string { return fmt.Sprintf("zfs exited with error: %s\nstderr:\n%s", e.WaitErr.Error(), e.Stderr) } @@ -187,7 +194,7 @@ func ZFSList(properties []string, zfsArgs ...string) (res [][]string, err error) } if waitErr := cmd.Wait(); waitErr != nil { - err := ZFSError{ + err := &ZFSError{ Stderr: stderr.Bytes(), WaitErr: waitErr, } @@ -227,18 +234,24 @@ func ZFSListChan(ctx context.Context, out chan ZFSListResult, properties []strin } } - cmd, err := rwccmd.CommandContext(ctx, ZFS_BINARY, args, []string{}) + cmd := exec.CommandContext(ctx, ZFS_BINARY, args...) 
+ stdout, err := cmd.StdoutPipe() if err != nil { sendResult(nil, err) return } + // TODO bounded buffer + stderr := bytes.NewBuffer(make([]byte, 0, 1024)) + cmd.Stderr = stderr if err = cmd.Start(); err != nil { sendResult(nil, err) return } - defer cmd.Close() + defer func() { + cmd.Wait() + }() - s := bufio.NewScanner(cmd) + s := bufio.NewScanner(stdout) buf := make([]byte, 1024) // max line length s.Buffer(buf, 0) @@ -252,8 +265,20 @@ func ZFSListChan(ctx context.Context, out chan ZFSListResult, properties []strin return } } + if err := cmd.Wait(); err != nil { + if err, ok := err.(*exec.ExitError); ok { + sendResult(nil, &ZFSError{ + Stderr: stderr.Bytes(), + WaitErr: err, + }) + } else { + sendResult(nil, &ZFSError{WaitErr: err}) + } + return + } if s.Err() != nil { sendResult(nil, s.Err()) + return } return } @@ -314,10 +339,180 @@ func buildCommonSendArgs(fs string, from, to string, token string) ([]string, er return args, nil } +type sendStreamCopier struct { + recorder readErrRecorder +} + +type readErrRecorder struct { + io.ReadCloser + readErr error +} + +type sendStreamCopierError struct { + isReadErr bool // if false, it's a write error + err error +} + +func (e sendStreamCopierError) Error() string { + if e.isReadErr { + return fmt.Sprintf("stream: read error: %s", e.err) + } else { + return fmt.Sprintf("stream: writer error: %s", e.err) + } +} + +func (e sendStreamCopierError) IsReadError() bool { return e.isReadErr } +func (e sendStreamCopierError) IsWriteError() bool { return !e.isReadErr } + +func (r *readErrRecorder) Read(p []byte) (n int, err error) { + n, err = r.ReadCloser.Read(p) + r.readErr = err + return n, err +} + +func newSendStreamCopier(stream io.ReadCloser) *sendStreamCopier { + return &sendStreamCopier{recorder: readErrRecorder{stream, nil}} +} + +func (c *sendStreamCopier) WriteStreamTo(w io.Writer) StreamCopierError { + debug("sendStreamCopier.WriteStreamTo: begin") + _, err := io.Copy(w, &c.recorder) + 
debug("sendStreamCopier.WriteStreamTo: copy done") + if err != nil { + if c.recorder.readErr != nil { + return sendStreamCopierError{isReadErr: true, err: c.recorder.readErr} + } else { + return sendStreamCopierError{isReadErr: false, err: err} + } + } + return nil +} + +func (c *sendStreamCopier) Read(p []byte) (n int, err error) { + return c.recorder.Read(p) +} + +func (c *sendStreamCopier) Close() error { + return c.recorder.ReadCloser.Close() +} + +func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) { + if capacity <= 0 { + panic(fmt.Sprintf("capacity must be positive %v", capacity)) + } + stdoutReader, stdoutWriter, err := os.Pipe() + if err != nil { + return nil, nil, err + } + trySetPipeCapacity(stdoutWriter, capacity) + return stdoutReader, stdoutWriter, nil +} + +type sendStream struct { + cmd *exec.Cmd + kill context.CancelFunc + + closeMtx sync.Mutex + stdoutReader *os.File + opErr error + +} + +func (s *sendStream) Read(p []byte) (n int, err error) { + s.closeMtx.Lock() + opErr := s.opErr + s.closeMtx.Unlock() + if opErr != nil { + return 0, opErr + } + + n, err = s.stdoutReader.Read(p) + if err != nil { + debug("sendStream: read err: %T %s", err, err) + // TODO we assume here that any read error is permanent + // which is most likely the case for a local zfs send + kwerr := s.killAndWait(err) + debug("sendStream: killAndWait n=%v err= %T %s", n, kwerr, kwerr) + // TODO we assume here that any read error is permanent + return n, kwerr + } + return n, err +} + +func (s *sendStream) Close() error { + debug("sendStream: close called") + return s.killAndWait(nil) +} + +func (s *sendStream) killAndWait(precedingReadErr error) error { + + debug("sendStream: killAndWait enter") + defer debug("sendStream: killAndWait leave") + if precedingReadErr == io.EOF { + // give the zfs process a little bit of time to terminate itself + // if it holds this deadline, exitErr will be nil + time.AfterFunc(200*time.Millisecond, s.kill) + } else { + s.kill() + } 
+ + // allow async kills from Close(), that's why we only take the mutex here + s.closeMtx.Lock() + defer s.closeMtx.Unlock() + + if s.opErr != nil { + return s.opErr + } + + waitErr := s.cmd.Wait() + // distinguish between ExitError (which is actually a non-problem for us) + // vs failed wait syscall (for which we give upper layers the chance to retyr) + var exitErr *exec.ExitError + if waitErr != nil { + if ee, ok := waitErr.(*exec.ExitError); ok { + exitErr = ee + } else { + return waitErr + } + } + + // now, after we know the program exited do we close the pipe + var closePipeErr error + if s.stdoutReader != nil { + closePipeErr = s.stdoutReader.Close() + if closePipeErr == nil { + // avoid double-closes in case anything below doesn't work + // and someone calls Close again + s.stdoutReader = nil + } else { + return closePipeErr + } + } + + // we managed to tear things down, no let's give the user some pretty *ZFSError + if exitErr != nil { + s.opErr = &ZFSError{ + Stderr: exitErr.Stderr, + WaitErr: exitErr, + } + } else { + s.opErr = fmt.Errorf("zfs send exited with status code 0") + } + + // detect the edge where we're called from s.Read + // after the pipe EOFed and zfs send exited without errors + // this is actullay the "hot" / nice path + if exitErr == nil && precedingReadErr == io.EOF { + return precedingReadErr + } + + return s.opErr +} + // if token != "", then send -t token is used // otherwise send [-i from] to is used // (if from is "" a full ZFS send is done) -func ZFSSend(ctx context.Context, fs string, from, to string, token string) (stream io.ReadCloser, err error) { +func ZFSSend(ctx context.Context, fs string, from, to string, token string) (streamCopier StreamCopier, err error) { args := make([]string, 0) args = append(args, "send") @@ -328,9 +523,33 @@ func ZFSSend(ctx context.Context, fs string, from, to string, token string) (str } args = append(args, sargs...) - stream, err = util.RunIOCommand(ctx, ZFS_BINARY, args...) 
+ ctx, cancel := context.WithCancel(ctx) + cmd := exec.CommandContext(ctx, ZFS_BINARY, args...) - return + // setup stdout with an os.Pipe to control pipe buffer size + stdoutReader, stdoutWriter, err := pipeWithCapacityHint(ZFSSendPipeCapacityHint) + if err != nil { + cancel() + return nil, err + } + + cmd.Stdout = stdoutWriter + + if err := cmd.Start(); err != nil { + cancel() + stdoutWriter.Close() + stdoutReader.Close() + return nil, err + } + stdoutWriter.Close() + + stream := &sendStream{ + cmd: cmd, + kill: cancel, + stdoutReader: stdoutReader, + } + + return newSendStreamCopier(stream), err } @@ -454,8 +673,26 @@ func ZFSSendDry(fs string, from, to string, token string) (_ *DrySendInfo, err e return &si, nil } +type StreamCopierError interface { + error + IsReadError() bool + IsWriteError() bool +} -func ZFSRecv(ctx context.Context, fs string, stream io.Reader, additionalArgs ...string) (err error) { +type StreamCopier interface { + // WriteStreamTo writes the stream represented by this StreamCopier + // to the given io.Writer. + WriteStreamTo(w io.Writer) StreamCopierError + // Close must be called as soon as it is clear that no more data will + // be read from the StreamCopier. + // If StreamCopier gets its data from a connection, it might hold + // a lock on the connection until Close is called. Only closing ensures + // that the connection can be used afterwards. + Close() error +} + + +func ZFSRecv(ctx context.Context, fs string, streamCopier StreamCopier, additionalArgs ...string) (err error) { if err := validateZFSFilesystem(fs); err != nil { return err @@ -468,6 +705,8 @@ func ZFSRecv(ctx context.Context, fs string, stream io.Reader, additionalArgs .. } args = append(args, fs) + ctx, cancelCmd := context.WithCancel(ctx) + defer cancelCmd() cmd := exec.CommandContext(ctx, ZFS_BINARY, args...) stderr := bytes.NewBuffer(make([]byte, 0, 1024)) @@ -480,21 +719,60 @@ func ZFSRecv(ctx context.Context, fs string, stream io.Reader, additionalArgs .. 
stdout := bytes.NewBuffer(make([]byte, 0, 1024)) cmd.Stdout = stdout - cmd.Stdin = stream + stdin, stdinWriter, err := pipeWithCapacityHint(ZFSRecvPipeCapacityHint) + if err != nil { + return err + } + + cmd.Stdin = stdin if err = cmd.Start(); err != nil { - return + stdinWriter.Close() + stdin.Close() + return err + } + stdin.Close() + defer stdinWriter.Close() + + pid := cmd.Process.Pid + debug := func(format string, args ...interface{}) { + debug("recv: pid=%v: %s", pid, fmt.Sprintf(format, args...)) } - if err = cmd.Wait(); err != nil { - err = ZFSError{ - Stderr: stderr.Bytes(), - WaitErr: err, + debug("started") + + copierErrChan := make(chan StreamCopierError) + go func() { + copierErrChan <- streamCopier.WriteStreamTo(stdinWriter) + }() + waitErrChan := make(chan *ZFSError) + go func() { + defer close(waitErrChan) + if err = cmd.Wait(); err != nil { + waitErrChan <- &ZFSError{ + Stderr: stderr.Bytes(), + WaitErr: err, + } + return } - return + }() + + // streamCopier always fails before or simultaneously with Wait + // thus receive from it first + copierErr := <-copierErrChan + debug("copierErr: %T %s", copierErr, copierErr) + if copierErr != nil { + cancelCmd() } - return nil + waitErr := <-waitErrChan + debug("waitErr: %T %s", waitErr, waitErr) + if copierErr == nil && waitErr == nil { + return nil + } else if waitErr != nil && (copierErr == nil || copierErr.IsWriteError()) { + return waitErr // has more interesting info in that case + } + return copierErr // if it's not a write error, the copier error is more interesting } type ClearResumeTokenError struct { @@ -572,7 +850,7 @@ func zfsSet(path string, props *ZFSProperties) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } @@ -689,7 +967,7 @@ func ZFSDestroy(dataset string) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } @@ -723,7 +1001,7 @@ func 
ZFSSnapshot(fs *DatasetPath, name string, recursive bool) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } @@ -751,7 +1029,7 @@ func ZFSBookmark(fs *DatasetPath, snapshot, bookmark string) (err error) { } if err = cmd.Wait(); err != nil { - err = ZFSError{ + err = &ZFSError{ Stderr: stderr.Bytes(), WaitErr: err, } diff --git a/zfs/zfs_debug.go b/zfs/zfs_debug.go new file mode 100644 index 0000000..32846e4 --- /dev/null +++ b/zfs/zfs_debug.go @@ -0,0 +1,20 @@ +package zfs + +import ( + "fmt" + "os" +) + +var debugEnabled bool = false + +func init() { + if os.Getenv("ZREPL_ZFS_DEBUG") != "" { + debugEnabled = true + } +} + +func debug(format string, args ...interface{}) { + if debugEnabled { + fmt.Fprintf(os.Stderr, "zfs: %s\n", fmt.Sprintf(format, args...)) + } +} diff --git a/zfs/zfs_pipe.go b/zfs/zfs_pipe.go new file mode 100644 index 0000000..4816441 --- /dev/null +++ b/zfs/zfs_pipe.go @@ -0,0 +1,18 @@ +// +build !linux + +package zfs + +import ( + "os" + "sync" +) + +var zfsPipeCapacityNotSupported sync.Once + +func trySetPipeCapacity(p *os.File, capacity int) { + if debugEnabled { + zfsPipeCapacityNotSupported.Do(func() { + debug("trySetPipeCapacity error: OS does not support setting pipe capacity") + }) + } +} diff --git a/zfs/zfs_pipe_linux.go b/zfs/zfs_pipe_linux.go new file mode 100644 index 0000000..ab50020 --- /dev/null +++ b/zfs/zfs_pipe_linux.go @@ -0,0 +1,21 @@ +package zfs + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +func trySetPipeCapacity(p *os.File, capacity int) { + res, err := unix.FcntlInt(p.Fd(), unix.F_SETPIPE_SZ, capacity) + if err != nil { + err = fmt.Errorf("cannot set pipe capacity to %v", capacity) + } else if res == -1 { + err = errors.New("cannot set pipe capacity: fcntl returned -1") + } + if debugEnabled && err != nil { + debug("trySetPipeCapacity error: %s\n", err) + } +} diff --git a/zfs/zfs_test.go b/zfs/zfs_test.go index 
6ffdfdf..9126bb0 100644 --- a/zfs/zfs_test.go +++ b/zfs/zfs_test.go @@ -15,7 +15,7 @@ func TestZFSListHandlesProducesZFSErrorOnNonZeroExit(t *testing.T) { _, err = ZFSList([]string{"fictionalprop"}, "nonexistent/dataset") assert.Error(t, err) - zfsError, ok := err.(ZFSError) + zfsError, ok := err.(*ZFSError) assert.True(t, ok) assert.Equal(t, "error: this is a mock\n", string(zfsError.Stderr)) } From 0230c6321f4f1b27b77e4b7efaf3d6b5d851c9c1 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sun, 23 Dec 2018 13:45:35 +0100 Subject: [PATCH 17/20] rpc/dataconn: microbenchmark --- Gopkg.lock | 9 ++ rpc/dataconn/microbenchmark/microbenchmark.go | 141 ++++++++++++------ util/devnoop/devnoop.go | 14 ++ 3 files changed, 118 insertions(+), 46 deletions(-) create mode 100644 util/devnoop/devnoop.go diff --git a/Gopkg.lock b/Gopkg.lock index c670de5..4674a95 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -177,6 +177,14 @@ revision = "645ef00459ed84a119197bfb8d8205042c6df63d" version = "v0.8.0" +[[projects]] + digest = "1:1cbc6b98173422a756ae79e485952cb37a0a460c710541c75d3e9961c5a60782" + name = "github.com/pkg/profile" + packages = ["."] + pruneopts = "" + revision = "5b67d428864e92711fcbd2f8629456121a56d91f" + version = "v1.2.1" + [[projects]] digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" name = "github.com/pmezard/go-difflib" @@ -399,6 +407,7 @@ "github.com/kr/pretty", "github.com/mattn/go-isatty", "github.com/pkg/errors", + "github.com/pkg/profile", "github.com/problame/go-netssh", "github.com/prometheus/client_golang/prometheus", "github.com/prometheus/client_golang/prometheus/promhttp", diff --git a/rpc/dataconn/microbenchmark/microbenchmark.go b/rpc/dataconn/microbenchmark/microbenchmark.go index 5ee58cf..287f7e1 100644 --- a/rpc/dataconn/microbenchmark/microbenchmark.go +++ b/rpc/dataconn/microbenchmark/microbenchmark.go @@ -1,7 +1,19 @@ +// microbenchmark to manually test rpc/dataconn perforamnce +// +// With stdin / stdout on client 
and server, simulating zfs send|recv piping +// +// ./microbenchmark -appmode server | pv -r > /dev/null +// ./microbenchmark -appmode client -direction recv < /dev/zero +// +// +// Without the overhead of pipes (just protocol perforamnce, mostly useful with perf bc no bw measurement) +// +// ./microbenchmark -appmode client -direction recv -devnoopWriter -devnoopReader +// ./microbenchmark -appmode server -devnoopReader -devnoopWriter +// package main import ( - "bytes" "context" "flag" "fmt" @@ -10,9 +22,14 @@ import ( "os" "github.com/pkg/profile" - "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/replication/pdu" + "github.com/zrepl/zrepl/rpc/dataconn" + "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" + "github.com/zrepl/zrepl/transport" + "github.com/zrepl/zrepl/util/devnoop" + "github.com/zrepl/zrepl/zfs" ) func orDie(err error) { @@ -21,54 +38,93 @@ func orDie(err error) { } } -type devNullHandler struct{} +type readerStreamCopier struct{ io.Reader } -func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { - var res pdu.SendRes - return &res, os.Stdin, nil +func (readerStreamCopier) Close() error { return nil } + +type readerStreamCopierErr struct { + error } -func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream io.Reader) (*pdu.ReceiveRes, error) { - var buf [1<<15]byte - _, err := io.CopyBuffer(os.Stdout, stream, buf[:]) +func (readerStreamCopierErr) IsReadError() bool { return false } +func (readerStreamCopierErr) IsWriteError() bool { return true } + +func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError { + var buf [1 << 21]byte + _, err := io.CopyBuffer(w, c.Reader, buf[:]) + // always assume write error + return readerStreamCopierErr{err} +} + +type devNullHandler struct{} + +func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { + var res pdu.SendRes + if 
args.devnoopReader { + return &res, readerStreamCopier{devnoop.Get()}, nil + } else { + return &res, readerStreamCopier{os.Stdin}, nil + } +} + +func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) { + var out io.Writer = os.Stdout + if args.devnoopWriter { + out = devnoop.Get() + } + err := stream.WriteStreamTo(out) var res pdu.ReceiveRes return &res, err } type tcpConnecter struct { - net, addr string + addr string } -func (c tcpConnecter) Connect(ctx context.Context) (net.Conn, error) { - return net.Dial(c.net, c.addr) +func (c tcpConnecter) Connect(ctx context.Context) (timeoutconn.Wire, error) { + conn, err := net.Dial("tcp", c.addr) + if err != nil { + return nil, err + } + return conn.(*net.TCPConn), nil } +type tcpListener struct { + nl *net.TCPListener + clientIdent string +} + +func (l tcpListener) Accept(ctx context.Context) (*transport.AuthConn, error) { + tcpconn, err := l.nl.AcceptTCP() + orDie(err) + return transport.NewAuthConn(tcpconn, l.clientIdent), nil +} + +func (l tcpListener) Addr() net.Addr { return l.nl.Addr() } + +func (l tcpListener) Close() error { return l.nl.Close() } + var args struct { - addr string - appmode string - direction string - profile bool + addr string + appmode string + direction string + profile bool + devnoopReader bool + devnoopWriter bool } func server() { log := logger.NewStderrDebugLogger() log.Debug("starting server") - l, err := net.Listen("tcp", args.addr) + nl, err := net.Listen("tcp", args.addr) orDie(err) + l := tcpListener{nl.(*net.TCPListener), "fakeclientidentity"} - srvConfig := dataconn.ServerConfig{ - Shared: dataconn.SharedConfig { - MaxProtoLen: 4096, - MaxHeaderLen: 4096, - SendChunkSize: 1 << 17, - MaxRecvChunkSize: 1 << 17, - }, - } - srv := dataconn.NewServer(devNullHandler{}, srvConfig, nil) + srv := dataconn.NewServer(nil, logger.NewStderrDebugLogger(), devNullHandler{}) ctx := context.Background() - ctx = dataconn.WithLogger(ctx, 
log) + srv.Serve(ctx, l) } @@ -76,6 +132,8 @@ func server() { func main() { flag.BoolVar(&args.profile, "profile", false, "") + flag.BoolVar(&args.devnoopReader, "devnoopReader", false, "") + flag.BoolVar(&args.devnoopWriter, "devnoopWriter", false, "") flag.StringVar(&args.addr, "address", ":8888", "") flag.StringVar(&args.appmode, "appmode", "client|server", "") flag.StringVar(&args.direction, "direction", "", "send|recv") @@ -99,34 +157,25 @@ func client() { logger := logger.NewStderrDebugLogger() ctx := context.Background() - ctx = dataconn.WithLogger(ctx, logger) - clientConfig := dataconn.ClientConfig{ - Shared: dataconn.SharedConfig { - MaxProtoLen: 4096, - MaxHeaderLen: 4096, - SendChunkSize: 1 << 17, - MaxRecvChunkSize: 1 << 17, - }, - } - orDie(clientConfig.Validate()) - - connecter := tcpConnecter{"tcp", args.addr} - client := dataconn.NewClient(connecter, clientConfig) + connecter := tcpConnecter{args.addr} + client := dataconn.NewClient(connecter, logger) switch args.direction { case "send": req := pdu.SendReq{} - _, stream, err := client.ReqSendStream(ctx, &req) + _, stream, err := client.ReqSend(ctx, &req) orDie(err) - var buf [1<<15]byte - _, err = io.CopyBuffer(os.Stdout, stream, buf[:]) + err = stream.WriteStreamTo(os.Stdout) orDie(err) case "recv": - var buf bytes.Buffer - buf.WriteString("teststreamtobereceived") + var r io.Reader = os.Stdin + if args.devnoopReader { + r = devnoop.Get() + } + s := readerStreamCopier{r} req := pdu.ReceiveReq{} - _, err := client.ReqRecv(ctx, &req, os.Stdin) + _, err := client.ReqRecv(ctx, &req, &s) orDie(err) default: orDie(fmt.Errorf("unknown direction%q", args.direction)) diff --git a/util/devnoop/devnoop.go b/util/devnoop/devnoop.go new file mode 100644 index 0000000..9c4074f --- /dev/null +++ b/util/devnoop/devnoop.go @@ -0,0 +1,14 @@ +// package devnoop provides an io.ReadWriteCloser that never errors +// and always reports reads / writes to / from buffers as complete. 
+// The buffers themselves are never touched. +package devnoop + +type Dev struct{} + +func Get() Dev { + return Dev{} +} + +func (Dev) Write(p []byte) (n int, err error) { return len(p), nil } +func (Dev) Read(p []byte) (n int, err error) { return len(p), nil } +func (Dev) Close() error { return nil } From a7993d18c65206c14b37396edd09faea5370969a Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 15 Mar 2019 17:17:25 +0100 Subject: [PATCH 18/20] transport/tls: clarify docs & error message language --- docs/configuration/transports.rst | 17 ++++++++++------- tlsconf/tlsconf.go | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/configuration/transports.rst b/docs/configuration/transports.rst index 20d1b25..075e918 100644 --- a/docs/configuration/transports.rst +++ b/docs/configuration/transports.rst @@ -76,9 +76,8 @@ Connect The ``tls`` transport uses TCP + TLS with client authentication using client certificates. The client identity is the common name (CN) presented in the client certificate. -It is recommended to set up a dedicated CA infrastructure for this transport, e.g. using OpenVPN's `EasyRSA `_. -When utilizing a CA infrastructure, provide a full chain certificate with the sender's certificate first in the list, with each following certificate directly certifying the one preceding it, per `TLS's specification`. +It is recommended to set up a dedicated CA infrastructure for this transport, e.g. using OpenVPN's `EasyRSA `_. For a simple 2-machine setup, see the :ref:`instructions below`. The implementation uses `Go's TLS library `_. @@ -87,6 +86,10 @@ Since Go binaries are statically linked, you or your distribution need to recomp All file paths are resolved relative to the zrepl daemon's working directory. Specify absolute paths if you are unsure what directory that is (or find out from your init system). 
+If intermediate CAs are used, the **full chain** must be present in either in the ``ca`` file or the individual ``cert`` files. +Regardless, the client's certificate must be first in the ``cert`` file, with each following certificate directly certifying the one preceding it (see `TLS's specification `_). +This is the common default when using a CA management tool. + Serve ~~~~~ @@ -98,9 +101,9 @@ Serve serve: type: tls listen: ":8888" - ca: /etc/zrepl/ca.crt - cert: /etc/zrepl/prod.crt - key: /etc/zrepl/prod.key + ca: /etc/zrepl/ca.crt + cert: /etc/zrepl/prod.fullchain + key: /etc/zrepl/prod.key client_cns: - "laptop1" - "homeserver" @@ -118,8 +121,8 @@ Connect connect: type: tls address: "server1.foo.bar:8888" - ca: /etc/zrepl/ca.crt - cert: /etc/zrepl/backupserver.crt + ca: /etc/zrepl/ca.crt + cert: /etc/zrepl/backupserver.fullchain key: /etc/zrepl/backupserver.key server_cn: "server1" dial_timeout: # optional, default 10s diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go index a5a4ea5..b1cb554 100644 --- a/tlsconf/tlsconf.go +++ b/tlsconf/tlsconf.go @@ -84,7 +84,7 @@ func (l *ClientAuthListener) Accept() (tcpConn *net.TCPConn, tlsConn *tls.Conn, peerCerts = tlsConn.ConnectionState().PeerCertificates if len(peerCerts) < 1 { - err = errors.New("unexpected number of certificates presented by TLS client") + err = errors.New("client must present full RFC5246:7.4.2 TLS client certificate chain") goto CloseAndErr } cn = peerCerts[0].Subject.CommonName From fc311a9fd6985b1c019ed7ab34b7657f6ab86e9b Mon Sep 17 00:00:00 2001 From: Ximalas Date: Fri, 1 Feb 2019 21:44:51 +0100 Subject: [PATCH 19/20] syslog logging: support setting facility in config --- config/config.go | 40 ++++++++++++++++++++++++++++++- config/config_global_test.go | 31 ++++++++++++++++++++++++ daemon/logging/build_logging.go | 1 + daemon/logging/logging_outlets.go | 3 ++- docs/configuration/logging.rst | 2 ++ 5 files changed, 75 insertions(+), 2 deletions(-) diff --git a/config/config.go 
b/config/config.go index 3c2131a..eb290b7 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import ( "github.com/pkg/errors" "github.com/zrepl/yaml-config" "io/ioutil" + "log/syslog" "os" "reflect" "regexp" @@ -38,7 +39,7 @@ func (j JobEnum) Name() string { case *PullJob: name = v.Name case *SourceJob: name = v.Name default: - panic(fmt.Sprintf("unknownn job type %T", v)) + panic(fmt.Sprintf("unknown job type %T", v)) } return name } @@ -258,6 +259,7 @@ type StdoutLoggingOutlet struct { type SyslogLoggingOutlet struct { LoggingOutletCommon `yaml:",inline"` + Facility syslog.Priority `yaml:"facility,default=local0"` RetryInterval time.Duration `yaml:"retry_interval,positive,default=10s"` } @@ -284,6 +286,16 @@ type PrometheusMonitoring struct { Listen string `yaml:"listen"` } +type SyslogFacilityEnum struct { + Ret interface{} +} + +type SyslogFacilityEnumList []SyslogFacilityEnum + +type SyslogFacility struct { + Facility syslog.Priority +} + type GlobalControl struct { SockPath string `yaml:"sockpath,default=/var/run/zrepl/control"` } @@ -389,6 +401,32 @@ func (t *MonitoringEnum) UnmarshalYAML(u func(interface{}, bool) error) (err err return } +func (t *SyslogFacilityEnum) UnmarshalYAML(u func(interface{}, bool) error) (err error) { + t.Ret, err = enumUnmarshal(u, map[string]interface{}{ + "kern": &SyslogFacility{syslog.LOG_KERN}, + "user": &SyslogFacility{syslog.LOG_USER}, + "mail": &SyslogFacility{syslog.LOG_MAIL}, + "daemon": &SyslogFacility{syslog.LOG_DAEMON}, + "auth": &SyslogFacility{syslog.LOG_AUTH}, + "syslog": &SyslogFacility{syslog.LOG_SYSLOG}, + "lpr": &SyslogFacility{syslog.LOG_LPR}, + "news": &SyslogFacility{syslog.LOG_NEWS}, + "uucp": &SyslogFacility{syslog.LOG_UUCP}, + "cron": &SyslogFacility{syslog.LOG_CRON}, + "authpriv": &SyslogFacility{syslog.LOG_AUTHPRIV}, + "ftp": &SyslogFacility{syslog.LOG_FTP}, + "local0": &SyslogFacility{syslog.LOG_LOCAL0}, + "local1": &SyslogFacility{syslog.LOG_LOCAL1}, + "local2": 
&SyslogFacility{syslog.LOG_LOCAL2}, + "local3": &SyslogFacility{syslog.LOG_LOCAL3}, + "local4": &SyslogFacility{syslog.LOG_LOCAL4}, + "local5": &SyslogFacility{syslog.LOG_LOCAL5}, + "local6": &SyslogFacility{syslog.LOG_LOCAL6}, + "local7": &SyslogFacility{syslog.LOG_LOCAL7}, + }) + return +} + var ConfigFileDefaultLocations = []string{ "/etc/zrepl/zrepl.yml", "/usr/local/etc/zrepl/zrepl.yml", diff --git a/config/config_global_test.go b/config/config_global_test.go index e14b936..a6c4f7b 100644 --- a/config/config_global_test.go +++ b/config/config_global_test.go @@ -1,9 +1,11 @@ package config import ( + "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zrepl/yaml-config" + "log/syslog" "testing" ) @@ -72,6 +74,35 @@ global: assert.Equal(t, ":9091", conf.Global.Monitoring[0].Ret.(*PrometheusMonitoring).Listen) } +func TestSyslogLoggingOutletFacility(t *testing.T) { + type SyslogFacilityPriority struct { + Facility string + Priority syslog.Priority + } + syslogFacilitiesPriorities := []SyslogFacilityPriority{ + {"kern", syslog.LOG_KERN}, {"daemon", syslog.LOG_DAEMON}, {"auth", syslog.LOG_AUTH}, + {"syslog", syslog.LOG_SYSLOG}, {"lpr", syslog.LOG_LPR}, {"news", syslog.LOG_NEWS}, + {"uucp", syslog.LOG_UUCP}, {"cron", syslog.LOG_CRON}, {"authpriv", syslog.LOG_AUTHPRIV}, + {"ftp", syslog.LOG_FTP}, {"local0", syslog.LOG_LOCAL0}, {"local1", syslog.LOG_LOCAL1}, + {"local2", syslog.LOG_LOCAL2}, {"local3", syslog.LOG_LOCAL3}, {"local4", syslog.LOG_LOCAL4}, + {"local5", syslog.LOG_LOCAL5}, {"local6", syslog.LOG_LOCAL6}, {"local7", syslog.LOG_LOCAL7}, + } + + for _, sFP := range syslogFacilitiesPriorities { + logcfg := fmt.Sprintf(` +global: + logging: + - type: syslog + level: info + format: human + facility: %s +`, sFP.Facility) + conf := testValidGlobalSection(t, logcfg) + assert.Equal(t, 1, len(*conf.Global.Logging)) + assert.Equal(t, sFP.Priority, (*conf.Global.Logging)[0].Ret.(*SyslogLoggingOutlet).Facility) + } +} + func 
TestLoggingOutletEnumList_SetDefaults(t *testing.T) { e := &LoggingOutletEnumList{} var i yaml.Defaulter = e diff --git a/daemon/logging/build_logging.go b/daemon/logging/build_logging.go index bcefec5..e4c2634 100644 --- a/daemon/logging/build_logging.go +++ b/daemon/logging/build_logging.go @@ -222,6 +222,7 @@ func parseSyslogOutlet(in *config.SyslogLoggingOutlet, formatter EntryFormatter) out = &SyslogOutlet{} out.Formatter = formatter out.Formatter.SetMetadataFlags(MetadataNone) + out.Facility = in.Facility out.RetryInterval = in.RetryInterval return out, nil } diff --git a/daemon/logging/logging_outlets.go b/daemon/logging/logging_outlets.go index 5a00d42..b03d008 100644 --- a/daemon/logging/logging_outlets.go +++ b/daemon/logging/logging_outlets.go @@ -124,6 +124,7 @@ func (h *TCPOutlet) WriteEntry(e logger.Entry) error { type SyslogOutlet struct { Formatter EntryFormatter RetryInterval time.Duration + Facility syslog.Priority writer *syslog.Writer lastConnectAttempt time.Time } @@ -142,7 +143,7 @@ func (o *SyslogOutlet) WriteEntry(entry logger.Entry) error { if now.Sub(o.lastConnectAttempt) < o.RetryInterval { return nil // not an error toward logger } - o.writer, err = syslog.New(syslog.LOG_LOCAL0, "zrepl") + o.writer, err = syslog.New(o.Facility, "zrepl") o.lastConnectAttempt = time.Now() if err != nil { o.writer = nil diff --git a/docs/configuration/logging.rst b/docs/configuration/logging.rst index f1a8466..a9077bb 100644 --- a/docs/configuration/logging.rst +++ b/docs/configuration/logging.rst @@ -147,6 +147,8 @@ Can only be specified once. 
- minimum :ref:`log level ` * - ``format`` - output :ref:`format ` + * - ``facility`` + - Which syslog facility to use (default = ``local0``) * - ``retry_interval`` - Interval between reconnection attempts to syslog (default = 0) From a0f301d7005835636d369be40e18c76357fa847e Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Fri, 15 Mar 2019 17:44:41 +0100 Subject: [PATCH 20/20] syslog logging: fix priority parsing + add test for default facility --- config/config.go | 70 ++++++++++++++++++--------------- config/config_global_test.go | 3 +- daemon/logging/build_logging.go | 3 +- 3 files changed, 42 insertions(+), 34 deletions(-) diff --git a/config/config.go b/config/config.go index eb290b7..7ff0ca9 100644 --- a/config/config.go +++ b/config/config.go @@ -259,7 +259,7 @@ type StdoutLoggingOutlet struct { type SyslogLoggingOutlet struct { LoggingOutletCommon `yaml:",inline"` - Facility syslog.Priority `yaml:"facility,default=local0"` + Facility *SyslogFacility `yaml:"facility,optional,fromdefaults"` RetryInterval time.Duration `yaml:"retry_interval,positive,default=10s"` } @@ -286,15 +286,13 @@ type PrometheusMonitoring struct { Listen string `yaml:"listen"` } -type SyslogFacilityEnum struct { - Ret interface{} +type SyslogFacility syslog.Priority + +func (f *SyslogFacility) SetDefault() { + *f = SyslogFacility(syslog.LOG_LOCAL0) } -type SyslogFacilityEnumList []SyslogFacilityEnum - -type SyslogFacility struct { - Facility syslog.Priority -} +var _ yaml.Defaulter = (*SyslogFacility)(nil) type GlobalControl struct { SockPath string `yaml:"sockpath,default=/var/run/zrepl/control"` @@ -401,30 +399,38 @@ func (t *MonitoringEnum) UnmarshalYAML(u func(interface{}, bool) error) (err err return } -func (t *SyslogFacilityEnum) UnmarshalYAML(u func(interface{}, bool) error) (err error) { - t.Ret, err = enumUnmarshal(u, map[string]interface{}{ - "kern": &SyslogFacility{syslog.LOG_KERN}, - "user": &SyslogFacility{syslog.LOG_USER}, - "mail": &SyslogFacility{syslog.LOG_MAIL}, 
- "daemon": &SyslogFacility{syslog.LOG_DAEMON}, - "auth": &SyslogFacility{syslog.LOG_AUTH}, - "syslog": &SyslogFacility{syslog.LOG_SYSLOG}, - "lpr": &SyslogFacility{syslog.LOG_LPR}, - "news": &SyslogFacility{syslog.LOG_NEWS}, - "uucp": &SyslogFacility{syslog.LOG_UUCP}, - "cron": &SyslogFacility{syslog.LOG_CRON}, - "authpriv": &SyslogFacility{syslog.LOG_AUTHPRIV}, - "ftp": &SyslogFacility{syslog.LOG_FTP}, - "local0": &SyslogFacility{syslog.LOG_LOCAL0}, - "local1": &SyslogFacility{syslog.LOG_LOCAL1}, - "local2": &SyslogFacility{syslog.LOG_LOCAL2}, - "local3": &SyslogFacility{syslog.LOG_LOCAL3}, - "local4": &SyslogFacility{syslog.LOG_LOCAL4}, - "local5": &SyslogFacility{syslog.LOG_LOCAL5}, - "local6": &SyslogFacility{syslog.LOG_LOCAL6}, - "local7": &SyslogFacility{syslog.LOG_LOCAL7}, - }) - return +func (t *SyslogFacility) UnmarshalYAML(u func(interface{}, bool) error) (err error) { + var s string + if err := u(&s, true); err != nil { + return err + } + var level syslog.Priority + switch s { + case "kern": level = syslog.LOG_KERN + case "user": level = syslog.LOG_USER + case "mail": level = syslog.LOG_MAIL + case "daemon": level = syslog.LOG_DAEMON + case "auth": level = syslog.LOG_AUTH + case "syslog": level = syslog.LOG_SYSLOG + case "lpr": level = syslog.LOG_LPR + case "news": level = syslog.LOG_NEWS + case "uucp": level = syslog.LOG_UUCP + case "cron": level = syslog.LOG_CRON + case "authpriv": level = syslog.LOG_AUTHPRIV + case "ftp": level = syslog.LOG_FTP + case "local0": level = syslog.LOG_LOCAL0 + case "local1": level = syslog.LOG_LOCAL1 + case "local2": level = syslog.LOG_LOCAL2 + case "local3": level = syslog.LOG_LOCAL3 + case "local4": level = syslog.LOG_LOCAL4 + case "local5": level = syslog.LOG_LOCAL5 + case "local6": level = syslog.LOG_LOCAL6 + case "local7": level = syslog.LOG_LOCAL7 + default: + return fmt.Errorf("invalid syslog level: %q", s) + } + *t = SyslogFacility(level) + return nil } var ConfigFileDefaultLocations = []string{ diff --git 
a/config/config_global_test.go b/config/config_global_test.go index a6c4f7b..51204b0 100644 --- a/config/config_global_test.go +++ b/config/config_global_test.go @@ -80,6 +80,7 @@ func TestSyslogLoggingOutletFacility(t *testing.T) { Priority syslog.Priority } syslogFacilitiesPriorities := []SyslogFacilityPriority{ + {"", syslog.LOG_LOCAL0}, // default {"kern", syslog.LOG_KERN}, {"daemon", syslog.LOG_DAEMON}, {"auth", syslog.LOG_AUTH}, {"syslog", syslog.LOG_SYSLOG}, {"lpr", syslog.LOG_LPR}, {"news", syslog.LOG_NEWS}, {"uucp", syslog.LOG_UUCP}, {"cron", syslog.LOG_CRON}, {"authpriv", syslog.LOG_AUTHPRIV}, @@ -99,7 +100,7 @@ global: `, sFP.Facility) conf := testValidGlobalSection(t, logcfg) assert.Equal(t, 1, len(*conf.Global.Logging)) - assert.Equal(t, sFP.Priority, (*conf.Global.Logging)[0].Ret.(*SyslogLoggingOutlet).Facility) + assert.True(t, SyslogFacility(sFP.Priority) == *(*conf.Global.Logging)[0].Ret.(*SyslogLoggingOutlet).Facility) } } diff --git a/daemon/logging/build_logging.go b/daemon/logging/build_logging.go index e4c2634..ce90d3c 100644 --- a/daemon/logging/build_logging.go +++ b/daemon/logging/build_logging.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "log/syslog" "os" "github.com/mattn/go-isatty" @@ -222,7 +223,7 @@ func parseSyslogOutlet(in *config.SyslogLoggingOutlet, formatter EntryFormatter) out = &SyslogOutlet{} out.Formatter = formatter out.Formatter.SetMetadataFlags(MetadataNone) - out.Facility = in.Facility + out.Facility = syslog.Priority(*in.Facility) out.RetryInterval = in.RetryInterval return out, nil }