mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-22 00:13:52 +01:00
Merge branch 'master' into fix_peer_cert_chains
This commit is contained in:
commit
5595cff6a6
38
.travis.yml
38
.travis.yml
@ -16,25 +16,6 @@ matrix:
|
||||
zrepl_build make vendordeps release
|
||||
|
||||
# all go entries vary only by go version
|
||||
- language: go
|
||||
go:
|
||||
- "1.10"
|
||||
go_import_path: github.com/zrepl/zrepl
|
||||
before_install:
|
||||
- wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip
|
||||
- echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c
|
||||
- sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip
|
||||
- ./lazy.sh godep
|
||||
- make vendordeps
|
||||
script:
|
||||
- make
|
||||
- make vet
|
||||
- make test
|
||||
- go test ./...
|
||||
- make artifacts/zrepl-freebsd-amd64
|
||||
- make artifacts/zrepl-linux-amd64
|
||||
- make artifacts/zrepl-darwin-amd64
|
||||
|
||||
- language: go
|
||||
go:
|
||||
- "1.11"
|
||||
@ -49,7 +30,24 @@ matrix:
|
||||
- make
|
||||
- make vet
|
||||
- make test
|
||||
- go test ./...
|
||||
- make artifacts/zrepl-freebsd-amd64
|
||||
- make artifacts/zrepl-linux-amd64
|
||||
- make artifacts/zrepl-darwin-amd64
|
||||
|
||||
- language: go
|
||||
go:
|
||||
- "master"
|
||||
go_import_path: github.com/zrepl/zrepl
|
||||
before_install:
|
||||
- wget https://github.com/protocolbuffers/protobuf/releases/download/v3.6.1/protoc-3.6.1-linux-x86_64.zip
|
||||
- echo "6003de742ea3fcf703cfec1cd4a3380fd143081a2eb0e559065563496af27807 protoc-3.6.1-linux-x86_64.zip" | sha256sum -c
|
||||
- sudo unzip -d /usr protoc-3.6.1-linux-x86_64.zip
|
||||
- ./lazy.sh godep
|
||||
- make vendordeps
|
||||
script:
|
||||
- make
|
||||
- make vet
|
||||
- make test
|
||||
- make artifacts/zrepl-freebsd-amd64
|
||||
- make artifacts/zrepl-linux-amd64
|
||||
- make artifacts/zrepl-darwin-amd64
|
||||
|
134
Gopkg.lock
generated
134
Gopkg.lock
generated
@ -80,6 +80,10 @@
|
||||
"protoc-gen-go/generator/internal/remap",
|
||||
"protoc-gen-go/grpc",
|
||||
"protoc-gen-go/plugin",
|
||||
"ptypes",
|
||||
"ptypes/any",
|
||||
"ptypes/duration",
|
||||
"ptypes/timestamp",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "aa810b61a9c79d51363740d207bb46cf8e620ed5"
|
||||
@ -173,6 +177,14 @@
|
||||
revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
|
||||
version = "v0.8.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:1cbc6b98173422a756ae79e485952cb37a0a460c710541c75d3e9961c5a60782"
|
||||
name = "github.com/pkg/profile"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "5b67d428864e92711fcbd2f8629456121a56d91f"
|
||||
version = "v1.2.1"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
|
||||
name = "github.com/pmezard/go-difflib"
|
||||
@ -183,30 +195,11 @@
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:25559b520313b941b1395cd5d5ee66086b27dc15a1391c0f2aad29d5c2321f4b"
|
||||
digest = "1:fa72f780ae3b4820ed12cef7a034291ab10d83e2da4ab5ba81afa44d5bf3a529"
|
||||
name = "github.com/problame/go-netssh"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "c56ad38d2c91397ad3c8dd9443d7448e328a9e9e"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:8c63c44f018bd52b03ebad65c9df26aabbc6793138e421df1c8c84c285a45bc6"
|
||||
name = "github.com/problame/go-rwccmd"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:1bcbb0a7ad8d3392d446eb583ae5415ff987838a8f7331a36877789be20667e6"
|
||||
name = "github.com/problame/go-streamrpc"
|
||||
packages = [
|
||||
".",
|
||||
"internal/pdu",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "d5d111e014342fe1c37f0b71cc37ec5f2afdfd13"
|
||||
version = "v0.5"
|
||||
revision = "09d6bc45d284784cb3e5aaa1998513f37eb19cc6"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
@ -279,31 +272,66 @@
|
||||
revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0"
|
||||
version = "v1.1.4"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:f80053a92d9ac874ad97d665d3005c1118ed66e9e88401361dc32defe6bef21c"
|
||||
name = "github.com/theckman/goconstraint"
|
||||
packages = ["go1.11/gte"]
|
||||
pruneopts = ""
|
||||
revision = "93babf24513d0e8277635da8169fcc5a46ae3f6a"
|
||||
version = "v1.11.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "v2"
|
||||
digest = "1:9d92186f609a73744232323416ddafd56fae67cb552162cc190ab903e36900dd"
|
||||
digest = "1:6b8a6afafde7ed31cd0c577ba40d88ce39e8f1c5eb76d7836be7d5b74f1c534a"
|
||||
name = "github.com/zrepl/yaml-config"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "af27d27978ad95808723a62d87557d63c3ff0605"
|
||||
revision = "08227ad854131f7dfcdfb12579fb73dd8a38a03a"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:9c286cf11d0ca56368185bada5dd6d97b6be4648fc26c354fcba8df7293718f7"
|
||||
digest = "1:ea539c13b066dac72a940b62f37600a20ab8e88057397c78f3197c1a48475425"
|
||||
name = "golang.org/x/net"
|
||||
packages = [
|
||||
"context",
|
||||
"http/httpguts",
|
||||
"http2",
|
||||
"http2/hpack",
|
||||
"idna",
|
||||
"internal/timeseries",
|
||||
"trace",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "351d144fa1fc0bd934e2408202be0c29f25e35a0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:f358024b019f87eecaadcb098113a40852c94fe58ea670ef3c3e2d2c7bd93db1"
|
||||
name = "golang.org/x/sys"
|
||||
packages = ["unix"]
|
||||
pruneopts = ""
|
||||
revision = "bf42f188b9bc6f2cf5b8ee5a912ef1aedd0eba4c"
|
||||
revision = "4ed8d59d0b35e1e29334a206d1b3f38b1e5dfb31"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:5acd3512b047305d49e8763eef7ba423901e85d5dd2fd1e71778a0ea8de10bd4"
|
||||
name = "golang.org/x/text"
|
||||
packages = [
|
||||
"collate",
|
||||
"collate/build",
|
||||
"encoding",
|
||||
"encoding/internal/identifier",
|
||||
"internal/colltab",
|
||||
"internal/gen",
|
||||
"internal/tag",
|
||||
"internal/triegen",
|
||||
"internal/ucd",
|
||||
"language",
|
||||
"secure/bidirule",
|
||||
"transform",
|
||||
"unicode/bidi",
|
||||
"unicode/cldr",
|
||||
"unicode/norm",
|
||||
"unicode/rangetable",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0"
|
||||
@ -317,6 +345,54 @@
|
||||
pruneopts = ""
|
||||
revision = "d0ca3933b724e6be513276cc2edb34e10d667438"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:5fc6c317675b746d0c641b29aa0aab5fcb403c0d07afdbf0de86b0d447a0502a"
|
||||
name = "google.golang.org/genproto"
|
||||
packages = ["googleapis/rpc/status"]
|
||||
pruneopts = ""
|
||||
revision = "bd91e49a0898e27abb88c339b432fa53d7497ac0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:d141efe4aaad714e3059c340901aab3147b6253e58c85dafbcca3dd8b0e88ad6"
|
||||
name = "google.golang.org/grpc"
|
||||
packages = [
|
||||
".",
|
||||
"balancer",
|
||||
"balancer/base",
|
||||
"balancer/roundrobin",
|
||||
"binarylog/grpc_binarylog_v1",
|
||||
"codes",
|
||||
"connectivity",
|
||||
"credentials",
|
||||
"credentials/internal",
|
||||
"encoding",
|
||||
"encoding/proto",
|
||||
"grpclog",
|
||||
"internal",
|
||||
"internal/backoff",
|
||||
"internal/binarylog",
|
||||
"internal/channelz",
|
||||
"internal/envconfig",
|
||||
"internal/grpcrand",
|
||||
"internal/grpcsync",
|
||||
"internal/syscall",
|
||||
"internal/transport",
|
||||
"keepalive",
|
||||
"metadata",
|
||||
"naming",
|
||||
"peer",
|
||||
"resolver",
|
||||
"resolver/dns",
|
||||
"resolver/passthrough",
|
||||
"stats",
|
||||
"status",
|
||||
"tap",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "df014850f6dee74ba2fc94874043a9f3f75fbfd8"
|
||||
version = "v1.17.0"
|
||||
|
||||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
@ -331,9 +407,8 @@
|
||||
"github.com/kr/pretty",
|
||||
"github.com/mattn/go-isatty",
|
||||
"github.com/pkg/errors",
|
||||
"github.com/pkg/profile",
|
||||
"github.com/problame/go-netssh",
|
||||
"github.com/problame/go-rwccmd",
|
||||
"github.com/problame/go-streamrpc",
|
||||
"github.com/prometheus/client_golang/prometheus",
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp",
|
||||
"github.com/spf13/cobra",
|
||||
@ -341,8 +416,13 @@
|
||||
"github.com/stretchr/testify/assert",
|
||||
"github.com/stretchr/testify/require",
|
||||
"github.com/zrepl/yaml-config",
|
||||
"golang.org/x/net/context",
|
||||
"golang.org/x/sys/unix",
|
||||
"golang.org/x/tools/cmd/stringer",
|
||||
"google.golang.org/grpc",
|
||||
"google.golang.org/grpc/credentials",
|
||||
"google.golang.org/grpc/keepalive",
|
||||
"google.golang.org/grpc/peer",
|
||||
]
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
42
Gopkg.toml
42
Gopkg.toml
@ -1,8 +1,11 @@
|
||||
ignored = [ "github.com/inconshreveable/mousetrap" ]
|
||||
ignored = [
|
||||
"github.com/inconshreveable/mousetrap",
|
||||
]
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/ftrvxmtrx/fd"
|
||||
required = [
|
||||
"golang.org/x/tools/cmd/stringer",
|
||||
"github.com/alvaroloes/enumer",
|
||||
]
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
@ -12,14 +15,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
|
||||
branch = "master"
|
||||
name = "github.com/kr/pretty"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/mitchellh/go-homedir"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/mitchellh/mapstructure"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/pkg/errors"
|
||||
version = "0.8.0"
|
||||
@ -28,10 +23,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
|
||||
branch = "master"
|
||||
name = "github.com/spf13/cobra"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/spf13/viper"
|
||||
version = "1.0.0"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/stretchr/testify"
|
||||
version = "1.1.4"
|
||||
@ -44,10 +35,6 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
|
||||
name = "github.com/go-logfmt/logfmt"
|
||||
version = "*"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/problame/go-rwccmd"
|
||||
branch = "master"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/problame/go-netssh"
|
||||
branch = "master"
|
||||
@ -58,26 +45,17 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/golang/protobuf"
|
||||
version = "1.2.0"
|
||||
version = "1"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/fatih/color"
|
||||
version = "1.7.0"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/problame/go-streamrpc"
|
||||
version = "0.5.0"
|
||||
|
||||
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/gdamore/tcell"
|
||||
version = "1.0.0"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/tools"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "github.com/alvaroloes/enumer"
|
||||
name = "google.golang.org/grpc"
|
||||
version = "1"
|
||||
|
64
Makefile
64
Makefile
@ -1,45 +1,13 @@
|
||||
.PHONY: generate build test vet cover release docs docs-clean clean vendordeps
|
||||
.DEFAULT_GOAL := build
|
||||
|
||||
ROOT := github.com/zrepl/zrepl
|
||||
SUBPKGS += client
|
||||
SUBPKGS += config
|
||||
SUBPKGS += daemon
|
||||
SUBPKGS += daemon/filters
|
||||
SUBPKGS += daemon/job
|
||||
SUBPKGS += daemon/logging
|
||||
SUBPKGS += daemon/nethelpers
|
||||
SUBPKGS += daemon/pruner
|
||||
SUBPKGS += daemon/snapper
|
||||
SUBPKGS += daemon/streamrpcconfig
|
||||
SUBPKGS += daemon/transport
|
||||
SUBPKGS += daemon/transport/connecter
|
||||
SUBPKGS += daemon/transport/serve
|
||||
SUBPKGS += endpoint
|
||||
SUBPKGS += logger
|
||||
SUBPKGS += pruning
|
||||
SUBPKGS += pruning/retentiongrid
|
||||
SUBPKGS += replication
|
||||
SUBPKGS += replication/fsrep
|
||||
SUBPKGS += replication/pdu
|
||||
SUBPKGS += replication/internal/diff
|
||||
SUBPKGS += tlsconf
|
||||
SUBPKGS += util
|
||||
SUBPKGS += util/socketpair
|
||||
SUBPKGS += util/watchdog
|
||||
SUBPKGS += util/envconst
|
||||
SUBPKGS += version
|
||||
SUBPKGS += zfs
|
||||
|
||||
_TESTPKGS := $(ROOT) $(foreach p,$(SUBPKGS),$(ROOT)/$(p))
|
||||
|
||||
ARTIFACTDIR := artifacts
|
||||
|
||||
ifdef ZREPL_VERSION
|
||||
_ZREPL_VERSION := $(ZREPL_VERSION)
|
||||
endif
|
||||
ifndef _ZREPL_VERSION
|
||||
_ZREPL_VERSION := $(shell git describe --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" )
|
||||
_ZREPL_VERSION := $(shell git describe --always --dirty 2>/dev/null || echo "ZREPL_BUILD_INVALID_VERSION" )
|
||||
ifeq ($(_ZREPL_VERSION),ZREPL_BUILD_INVALID_VERSION) # can't use .SHELLSTATUS because Debian Stretch is still on gmake 4.1
|
||||
$(error cannot infer variable ZREPL_VERSION using git and variable is not overriden by make invocation)
|
||||
endif
|
||||
@ -59,35 +27,21 @@ vendordeps:
|
||||
dep ensure -v -vendor-only
|
||||
|
||||
generate: #not part of the build, must do that manually
|
||||
protoc -I=replication/pdu --go_out=replication/pdu replication/pdu/pdu.proto
|
||||
@for pkg in $(_TESTPKGS); do\
|
||||
go generate "$$pkg" || exit 1; \
|
||||
done;
|
||||
protoc -I=replication/pdu --go_out=plugins=grpc:replication/pdu replication/pdu/pdu.proto
|
||||
go generate -x ./...
|
||||
|
||||
build:
|
||||
@echo "INFO: In case of missing dependencies, run 'make vendordeps'"
|
||||
$(GO_BUILD) -o "$(ARTIFACTDIR)/zrepl"
|
||||
|
||||
test:
|
||||
@for pkg in $(_TESTPKGS); do \
|
||||
echo "Testing $$pkg"; \
|
||||
go test "$$pkg" || exit 1; \
|
||||
done;
|
||||
go test ./...
|
||||
|
||||
vet:
|
||||
@for pkg in $(_TESTPKGS); do \
|
||||
echo "Vetting $$pkg"; \
|
||||
go vet "$$pkg" || exit 1; \
|
||||
done;
|
||||
|
||||
cover: artifacts
|
||||
@for pkg in $(_TESTPKGS); do \
|
||||
profile="$(ARTIFACTDIR)/cover-$$(basename $$pkg).out"; \
|
||||
go test -coverprofile "$$profile" $$pkg || exit 1; \
|
||||
if [ -f "$$profile" ]; then \
|
||||
go tool cover -html="$$profile" -o "$${profile}.html" || exit 2; \
|
||||
fi; \
|
||||
done;
|
||||
# for each supported platform to cover conditional compilation
|
||||
GOOS=linux go vet ./...
|
||||
GOOS=darwin go vet ./...
|
||||
GOOS=freebsd go vet ./...
|
||||
|
||||
$(ARTIFACTDIR):
|
||||
mkdir -p "$@"
|
||||
@ -132,7 +86,7 @@ release: $(RELEASE_BINS) $(RELEASE_NOARCH)
|
||||
cp $^ "$(ARTIFACTDIR)/release"
|
||||
cd "$(ARTIFACTDIR)/release" && sha512sum $$(ls | sort) > sha512sum.txt
|
||||
@# note that we use ZREPL_VERSION and not _ZREPL_VERSION because we want to detect the override
|
||||
@if git describe --dirty 2>/dev/null | grep dirty >/dev/null; then \
|
||||
@if git describe --always --dirty 2>/dev/null | grep dirty >/dev/null; then \
|
||||
echo '[INFO] either git reports checkout is dirty or git is not installed or this is not a git checkout'; \
|
||||
if [ "$(ZREPL_VERSION)" = "" ]; then \
|
||||
echo '[WARN] git checkout is dirty and make variable ZREPL_VERSION was not used to override'; \
|
||||
|
@ -130,7 +130,6 @@ type Global struct {
|
||||
Monitoring []MonitoringEnum `yaml:"monitoring,optional"`
|
||||
Control *GlobalControl `yaml:"control,optional,fromdefaults"`
|
||||
Serve *GlobalServe `yaml:"serve,optional,fromdefaults"`
|
||||
RPC *RPCConfig `yaml:"rpc,optional,fromdefaults"`
|
||||
}
|
||||
|
||||
func Default(i interface{}) {
|
||||
@ -145,29 +144,18 @@ func Default(i interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
type RPCConfig struct {
|
||||
Timeout time.Duration `yaml:"timeout,optional,positive,default=10s"`
|
||||
TxChunkSize uint32 `yaml:"tx_chunk_size,optional,default=32768"`
|
||||
RxStructuredMaxLen uint32 `yaml:"rx_structured_max,optional,default=16777216"`
|
||||
RxStreamChunkMaxLen uint32 `yaml:"rx_stream_chunk_max,optional,default=16777216"`
|
||||
RxHeaderMaxLen uint32 `yaml:"rx_header_max,optional,default=40960"`
|
||||
SendHeartbeatInterval time.Duration `yaml:"send_heartbeat_interval,optional,positive,default=5s"`
|
||||
|
||||
}
|
||||
|
||||
type ConnectEnum struct {
|
||||
Ret interface{}
|
||||
}
|
||||
|
||||
type ConnectCommon struct {
|
||||
Type string `yaml:"type"`
|
||||
RPC *RPCConfig `yaml:"rpc,optional"`
|
||||
}
|
||||
|
||||
type TCPConnect struct {
|
||||
ConnectCommon `yaml:",inline"`
|
||||
Address string `yaml:"address"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"`
|
||||
}
|
||||
|
||||
type TLSConnect struct {
|
||||
@ -177,7 +165,7 @@ type TLSConnect struct {
|
||||
Cert string `yaml:"cert"`
|
||||
Key string `yaml:"key"`
|
||||
ServerCN string `yaml:"server_cn"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"`
|
||||
}
|
||||
|
||||
type SSHStdinserverConnect struct {
|
||||
@ -189,7 +177,7 @@ type SSHStdinserverConnect struct {
|
||||
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
|
||||
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
|
||||
Options []string `yaml:"options,optional"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,zeropositive,default=10s"`
|
||||
}
|
||||
|
||||
type LocalConnect struct {
|
||||
@ -204,7 +192,6 @@ type ServeEnum struct {
|
||||
|
||||
type ServeCommon struct {
|
||||
Type string `yaml:"type"`
|
||||
RPC *RPCConfig `yaml:"rpc,optional"`
|
||||
}
|
||||
|
||||
type TCPServe struct {
|
||||
@ -220,7 +207,7 @@ type TLSServe struct {
|
||||
Cert string `yaml:"cert"`
|
||||
Key string `yaml:"key"`
|
||||
ClientCNs []string `yaml:"client_cns"`
|
||||
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"`
|
||||
HandshakeTimeout time.Duration `yaml:"handshake_timeout,zeropositive,default=10s"`
|
||||
}
|
||||
|
||||
type StdinserverServer struct {
|
||||
|
@ -1,86 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRPC(t *testing.T) {
|
||||
conf := testValidConfig(t, `
|
||||
jobs:
|
||||
- name: pull_servers
|
||||
type: pull
|
||||
connect:
|
||||
type: tcp
|
||||
address: "server1.foo.bar:8888"
|
||||
rpc:
|
||||
timeout: 20s # different form default, should merge
|
||||
root_fs: "pool2/backup_servers"
|
||||
interval: 10m
|
||||
pruning:
|
||||
keep_sender:
|
||||
- type: not_replicated
|
||||
keep_receiver:
|
||||
- type: last_n
|
||||
count: 100
|
||||
|
||||
- name: pull_servers2
|
||||
type: pull
|
||||
connect:
|
||||
type: tcp
|
||||
address: "server1.foo.bar:8888"
|
||||
rpc:
|
||||
tx_chunk_size: 0xabcd # different from default, should merge
|
||||
root_fs: "pool2/backup_servers"
|
||||
interval: 10m
|
||||
pruning:
|
||||
keep_sender:
|
||||
- type: not_replicated
|
||||
keep_receiver:
|
||||
- type: last_n
|
||||
count: 100
|
||||
|
||||
- type: sink
|
||||
name: "laptop_sink"
|
||||
root_fs: "pool2/backup_laptops"
|
||||
serve:
|
||||
type: tcp
|
||||
listen: "192.168.122.189:8888"
|
||||
clients: {
|
||||
"10.23.42.23":"client1"
|
||||
}
|
||||
rpc:
|
||||
rx_structured_max: 0x2342
|
||||
|
||||
- type: sink
|
||||
name: "other_sink"
|
||||
root_fs: "pool2/backup_laptops"
|
||||
serve:
|
||||
type: tcp
|
||||
listen: "192.168.122.189:8888"
|
||||
clients: {
|
||||
"10.23.42.23":"client1"
|
||||
}
|
||||
rpc:
|
||||
send_heartbeat_interval: 10s
|
||||
|
||||
`)
|
||||
|
||||
assert.Equal(t, 20*time.Second, conf.Jobs[0].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.Timeout)
|
||||
assert.Equal(t, uint32(0xabcd), conf.Jobs[1].Ret.(*PullJob).Connect.Ret.(*TCPConnect).RPC.TxChunkSize)
|
||||
assert.Equal(t, uint32(0x2342), conf.Jobs[2].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.RxStructuredMaxLen)
|
||||
assert.Equal(t, 10*time.Second, conf.Jobs[3].Ret.(*SinkJob).Serve.Ret.(*TCPServe).RPC.SendHeartbeatInterval)
|
||||
defConf := RPCConfig{}
|
||||
Default(&defConf)
|
||||
assert.Equal(t, defConf.Timeout, conf.Global.RPC.Timeout)
|
||||
}
|
||||
|
||||
func TestGlobal_DefaultRPCConfig(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
var c RPCConfig
|
||||
Default(&c)
|
||||
assert.NotNil(t, c)
|
||||
assert.Equal(t, c.TxChunkSize, uint32(1)<<15)
|
||||
})
|
||||
}
|
@ -1,11 +1,14 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/kr/pretty"
|
||||
"github.com/stretchr/testify/require"
|
||||
"bytes"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/kr/pretty"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) {
|
||||
@ -35,8 +38,21 @@ func TestSampleConfigsAreParsedWithoutErrors(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
// template must be a template/text template with a single '{{ . }}' as placehodler for val
|
||||
func testValidConfigTemplate(t *testing.T, tmpl string, val string) *Config {
|
||||
tmp, err := template.New("master").Parse(tmpl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
err = tmp.Execute(&buf, val)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return testValidConfig(t, buf.String())
|
||||
}
|
||||
|
||||
func testValidConfig(t *testing.T, input string) (*Config) {
|
||||
func testValidConfig(t *testing.T, input string) *Config {
|
||||
t.Helper()
|
||||
conf, err := testConfig(t, input)
|
||||
require.NoError(t, err)
|
||||
|
@ -2,9 +2,12 @@ package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/filters"
|
||||
"github.com/zrepl/zrepl/daemon/job/reset"
|
||||
@ -12,23 +15,22 @@ import (
|
||||
"github.com/zrepl/zrepl/daemon/logging"
|
||||
"github.com/zrepl/zrepl/daemon/pruner"
|
||||
"github.com/zrepl/zrepl/daemon/snapper"
|
||||
"github.com/zrepl/zrepl/daemon/transport/connecter"
|
||||
"github.com/zrepl/zrepl/endpoint"
|
||||
"github.com/zrepl/zrepl/replication"
|
||||
"github.com/zrepl/zrepl/rpc"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/transport/fromconfig"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ActiveSide struct {
|
||||
mode activeMode
|
||||
name string
|
||||
clientFactory *connecter.ClientFactory
|
||||
connecter transport.Connecter
|
||||
|
||||
prunerFactory *pruner.PrunerFactory
|
||||
|
||||
|
||||
promRepStateSecs *prometheus.HistogramVec // labels: state
|
||||
promPruneSecs *prometheus.HistogramVec // labels: prune_side
|
||||
promBytesReplicated *prometheus.CounterVec // labels: filesystem
|
||||
@ -37,7 +39,6 @@ type ActiveSide struct {
|
||||
tasks activeSideTasks
|
||||
}
|
||||
|
||||
|
||||
//go:generate enumer -type=ActiveSideState
|
||||
type ActiveSideState int
|
||||
|
||||
@ -48,7 +49,6 @@ const (
|
||||
ActiveSideDone // also errors
|
||||
)
|
||||
|
||||
|
||||
type activeSideTasks struct {
|
||||
state ActiveSideState
|
||||
|
||||
@ -77,20 +77,44 @@ func (a *ActiveSide) updateTasks(u func(*activeSideTasks)) activeSideTasks {
|
||||
}
|
||||
|
||||
type activeMode interface {
|
||||
SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error)
|
||||
ConnectEndpoints(rpcLoggers rpc.Loggers, connecter transport.Connecter)
|
||||
DisconnectEndpoints()
|
||||
SenderReceiver() (replication.Sender, replication.Receiver)
|
||||
Type() Type
|
||||
RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{})
|
||||
ResetConnectBackoff()
|
||||
}
|
||||
|
||||
type modePush struct {
|
||||
setupMtx sync.Mutex
|
||||
sender *endpoint.Sender
|
||||
receiver *rpc.Client
|
||||
fsfilter endpoint.FSFilter
|
||||
snapper *snapper.PeriodicOrManual
|
||||
}
|
||||
|
||||
func (m *modePush) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) {
|
||||
sender := endpoint.NewSender(m.fsfilter)
|
||||
receiver := endpoint.NewRemote(client)
|
||||
return sender, receiver, nil
|
||||
func (m *modePush) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
if m.receiver != nil || m.sender != nil {
|
||||
panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints")
|
||||
}
|
||||
m.sender = endpoint.NewSender(m.fsfilter)
|
||||
m.receiver = rpc.NewClient(connecter, loggers)
|
||||
}
|
||||
|
||||
func (m *modePush) DisconnectEndpoints() {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
m.receiver.Close()
|
||||
m.sender = nil
|
||||
m.receiver = nil
|
||||
}
|
||||
|
||||
func (m *modePush) SenderReceiver() (replication.Sender, replication.Receiver) {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
return m.sender, m.receiver
|
||||
}
|
||||
|
||||
func (m *modePush) Type() Type { return TypePush }
|
||||
@ -99,6 +123,13 @@ func (m *modePush) RunPeriodic(ctx context.Context, wakeUpCommon chan <- struct{
|
||||
m.snapper.Run(ctx, wakeUpCommon)
|
||||
}
|
||||
|
||||
func (m *modePush) ResetConnectBackoff() {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
if m.receiver != nil {
|
||||
m.receiver.ResetConnectBackoff()
|
||||
}
|
||||
}
|
||||
|
||||
func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error) {
|
||||
m := &modePush{}
|
||||
@ -116,14 +147,35 @@ func modePushFromConfig(g *config.Global, in *config.PushJob) (*modePush, error)
|
||||
}
|
||||
|
||||
type modePull struct {
|
||||
setupMtx sync.Mutex
|
||||
receiver *endpoint.Receiver
|
||||
sender *rpc.Client
|
||||
rootFS *zfs.DatasetPath
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
func (m *modePull) SenderReceiver(client *streamrpc.Client) (replication.Sender, replication.Receiver, error) {
|
||||
sender := endpoint.NewRemote(client)
|
||||
receiver, err := endpoint.NewReceiver(m.rootFS)
|
||||
return sender, receiver, err
|
||||
func (m *modePull) ConnectEndpoints(loggers rpc.Loggers, connecter transport.Connecter) {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
if m.receiver != nil || m.sender != nil {
|
||||
panic("inconsistent use of ConnectEndpoints and DisconnectEndpoints")
|
||||
}
|
||||
m.receiver = endpoint.NewReceiver(m.rootFS, false)
|
||||
m.sender = rpc.NewClient(connecter, loggers)
|
||||
}
|
||||
|
||||
func (m *modePull) DisconnectEndpoints() {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
m.sender.Close()
|
||||
m.sender = nil
|
||||
m.receiver = nil
|
||||
}
|
||||
|
||||
func (m *modePull) SenderReceiver() (replication.Sender, replication.Receiver) {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
return m.sender, m.receiver
|
||||
}
|
||||
|
||||
func (*modePull) Type() Type { return TypePull }
|
||||
@ -148,6 +200,14 @@ func (m *modePull) RunPeriodic(ctx context.Context, wakeUpCommon chan<- struct{}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *modePull) ResetConnectBackoff() {
|
||||
m.setupMtx.Lock()
|
||||
defer m.setupMtx.Unlock()
|
||||
if m.sender != nil {
|
||||
m.sender.ResetConnectBackoff()
|
||||
}
|
||||
}
|
||||
|
||||
func modePullFromConfig(g *config.Global, in *config.PullJob) (m *modePull, err error) {
|
||||
m = &modePull{}
|
||||
if in.Interval <= 0 {
|
||||
@ -185,7 +245,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, mode activeMode) (j *Act
|
||||
ConstLabels: prometheus.Labels{"zrepl_job": j.name},
|
||||
}, []string{"filesystem"})
|
||||
|
||||
j.clientFactory, err = connecter.FromConfig(g, in.Connect)
|
||||
j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot build client")
|
||||
}
|
||||
@ -256,6 +316,7 @@ outer:
|
||||
break outer
|
||||
|
||||
case <-wakeup.Wait(ctx):
|
||||
j.mode.ResetConnectBackoff()
|
||||
case <-periodicDone:
|
||||
}
|
||||
invocationCount++
|
||||
@ -268,6 +329,9 @@ func (j *ActiveSide) do(ctx context.Context) {
|
||||
|
||||
log := GetLogger(ctx)
|
||||
ctx = logging.WithSubsystemLoggers(ctx, log)
|
||||
loggers := rpc.GetLoggersOrPanic(ctx) // filled by WithSubsystemLoggers
|
||||
j.mode.ConnectEndpoints(loggers, j.connecter)
|
||||
defer j.mode.DisconnectEndpoints()
|
||||
|
||||
// allow cancellation of an invocation (this function)
|
||||
ctx, cancelThisRun := context.WithCancel(ctx)
|
||||
@ -353,13 +417,7 @@ func (j *ActiveSide) do(ctx context.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := j.clientFactory.NewClient()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("factory cannot instantiate streamrpc client")
|
||||
}
|
||||
defer client.Close(ctx)
|
||||
|
||||
sender, receiver, err := j.mode.SenderReceiver(client)
|
||||
sender, receiver := j.mode.SenderReceiver()
|
||||
|
||||
{
|
||||
select {
|
||||
|
@ -2,28 +2,30 @@ package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/filters"
|
||||
"github.com/zrepl/zrepl/daemon/logging"
|
||||
"github.com/zrepl/zrepl/daemon/transport/serve"
|
||||
"github.com/zrepl/zrepl/daemon/snapper"
|
||||
"github.com/zrepl/zrepl/endpoint"
|
||||
"github.com/zrepl/zrepl/rpc"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/transport/fromconfig"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"path"
|
||||
)
|
||||
|
||||
type PassiveSide struct {
|
||||
mode passiveMode
|
||||
name string
|
||||
l serve.ListenerFactory
|
||||
rpcConf *streamrpc.ConnConfig
|
||||
listen transport.AuthenticatedListenerFactory
|
||||
}
|
||||
|
||||
type passiveMode interface {
|
||||
ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc
|
||||
Handler() rpc.Handler
|
||||
RunPeriodic(ctx context.Context)
|
||||
Type() Type
|
||||
}
|
||||
@ -34,26 +36,8 @@ type modeSink struct {
|
||||
|
||||
func (m *modeSink) Type() Type { return TypeSink }
|
||||
|
||||
func (m *modeSink) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc {
|
||||
log := GetLogger(ctx)
|
||||
|
||||
clientRootStr := path.Join(m.rootDataset.ToString(), conn.ClientIdentity())
|
||||
clientRoot, err := zfs.NewDatasetPath(clientRootStr)
|
||||
if err != nil {
|
||||
log.WithError(err).
|
||||
WithField("client_identity", conn.ClientIdentity()).
|
||||
Error("cannot build client filesystem map (client identity must be a valid ZFS FS name")
|
||||
}
|
||||
log.WithField("client_root", clientRoot).Debug("client root")
|
||||
|
||||
local, err := endpoint.NewReceiver(clientRoot)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
|
||||
return nil
|
||||
}
|
||||
|
||||
h := endpoint.NewHandler(local)
|
||||
return h.Handle
|
||||
func (m *modeSink) Handler() rpc.Handler {
|
||||
return endpoint.NewReceiver(m.rootDataset, true)
|
||||
}
|
||||
|
||||
func (m *modeSink) RunPeriodic(_ context.Context) {}
|
||||
@ -93,10 +77,8 @@ func modeSourceFromConfig(g *config.Global, in *config.SourceJob) (m *modeSource
|
||||
|
||||
func (m *modeSource) Type() Type { return TypeSource }
|
||||
|
||||
func (m *modeSource) ConnHandleFunc(ctx context.Context, conn serve.AuthenticatedConn) streamrpc.HandlerFunc {
|
||||
sender := endpoint.NewSender(m.fsfilter)
|
||||
h := endpoint.NewHandler(sender)
|
||||
return h.Handle
|
||||
func (m *modeSource) Handler() rpc.Handler {
|
||||
return endpoint.NewSender(m.fsfilter)
|
||||
}
|
||||
|
||||
func (m *modeSource) RunPeriodic(ctx context.Context) {
|
||||
@ -106,8 +88,8 @@ func (m *modeSource) RunPeriodic(ctx context.Context) {
|
||||
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, mode passiveMode) (s *PassiveSide, err error) {
|
||||
|
||||
s = &PassiveSide{mode: mode, name: in.Name}
|
||||
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil {
|
||||
return nil, errors.Wrap(err, "cannot build server")
|
||||
if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil {
|
||||
return nil, errors.Wrap(err, "cannot build listener factory")
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@ -127,70 +109,30 @@ func (j *PassiveSide) Run(ctx context.Context) {
|
||||
|
||||
log := GetLogger(ctx)
|
||||
defer log.Info("job exiting")
|
||||
|
||||
l, err := j.l.Listen()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("cannot listen")
|
||||
return
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
ctx = logging.WithSubsystemLoggers(ctx, log)
|
||||
{
|
||||
ctx, cancel := context.WithCancel(logging.WithSubsystemLoggers(ctx, log)) // shadowing
|
||||
ctx, cancel := context.WithCancel(ctx) // shadowing
|
||||
defer cancel()
|
||||
go j.mode.RunPeriodic(ctx)
|
||||
}
|
||||
|
||||
log.WithField("addr", l.Addr()).Debug("accepting connections")
|
||||
var connId int
|
||||
outer:
|
||||
for {
|
||||
|
||||
select {
|
||||
case res := <-accept(ctx, l):
|
||||
if res.err != nil {
|
||||
log.WithError(res.err).Info("accept error")
|
||||
continue
|
||||
handler := j.mode.Handler()
|
||||
if handler == nil {
|
||||
panic(fmt.Sprintf("implementation error: j.mode.Handler() returned nil: %#v", j))
|
||||
}
|
||||
conn := res.conn
|
||||
connId++
|
||||
connLog := log.
|
||||
WithField("connID", connId)
|
||||
connLog.
|
||||
WithField("addr", conn.RemoteAddr()).
|
||||
WithField("client_identity", conn.ClientIdentity()).
|
||||
Info("handling connection")
|
||||
go func() {
|
||||
defer connLog.Info("finished handling connection")
|
||||
defer conn.Close()
|
||||
ctx := logging.WithSubsystemLoggers(ctx, connLog)
|
||||
handleFunc := j.mode.ConnHandleFunc(ctx, conn)
|
||||
if handleFunc == nil {
|
||||
|
||||
ctxInterceptor := func(handlerCtx context.Context) context.Context {
|
||||
return logging.WithSubsystemLoggers(handlerCtx, log)
|
||||
}
|
||||
|
||||
rpcLoggers := rpc.GetLoggersOrPanic(ctx) // WithSubsystemLoggers above
|
||||
server := rpc.NewServer(handler, rpcLoggers, ctxInterceptor)
|
||||
|
||||
listener, err := j.listen()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("cannot listen")
|
||||
return
|
||||
}
|
||||
if err := streamrpc.ServeConn(ctx, conn, j.rpcConf, handleFunc); err != nil {
|
||||
log.WithError(err).Error("error serving client")
|
||||
}
|
||||
}()
|
||||
|
||||
case <-ctx.Done():
|
||||
break outer
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type acceptResult struct {
|
||||
conn serve.AuthenticatedConn
|
||||
err error
|
||||
}
|
||||
|
||||
func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult {
|
||||
c := make(chan acceptResult, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept(ctx)
|
||||
c <- acceptResult{conn, err}
|
||||
}()
|
||||
return c
|
||||
server.Serve(ctx, listener)
|
||||
}
|
||||
|
@ -1,32 +0,0 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type streamrpcLogAdaptor = twoClassLogAdaptor
|
||||
|
||||
type twoClassLogAdaptor struct {
|
||||
logger.Logger
|
||||
}
|
||||
|
||||
var _ streamrpc.Logger = twoClassLogAdaptor{}
|
||||
|
||||
func (a twoClassLogAdaptor) Errorf(fmtStr string, args ...interface{}) {
|
||||
const errorSuffix = ": %s"
|
||||
if len(args) == 1 {
|
||||
if err, ok := args[0].(error); ok && strings.HasSuffix(fmtStr, errorSuffix) {
|
||||
msg := strings.TrimSuffix(fmtStr, errorSuffix)
|
||||
a.WithError(err).Error(msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
a.Logger.Error(fmt.Sprintf(fmtStr, args...))
|
||||
}
|
||||
|
||||
func (a twoClassLogAdaptor) Infof(fmtStr string, args ...interface{}) {
|
||||
a.Logger.Debug(fmt.Sprintf(fmtStr, args...))
|
||||
}
|
@ -4,18 +4,21 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-streamrpc"
|
||||
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/pruner"
|
||||
"github.com/zrepl/zrepl/daemon/snapper"
|
||||
"github.com/zrepl/zrepl/endpoint"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/replication"
|
||||
"github.com/zrepl/zrepl/rpc"
|
||||
"github.com/zrepl/zrepl/rpc/transportmux"
|
||||
"github.com/zrepl/zrepl/tlsconf"
|
||||
"os"
|
||||
"github.com/zrepl/zrepl/daemon/snapper"
|
||||
"github.com/zrepl/zrepl/daemon/transport/serve"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
|
||||
@ -60,22 +63,41 @@ func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error)
|
||||
|
||||
}
|
||||
|
||||
type Subsystem string
|
||||
|
||||
const (
|
||||
SubsysReplication = "repl"
|
||||
SubsysStreamrpc = "rpc"
|
||||
SubsyEndpoint = "endpoint"
|
||||
SubsysReplication Subsystem = "repl"
|
||||
SubsyEndpoint Subsystem = "endpoint"
|
||||
SubsysPruning Subsystem = "pruning"
|
||||
SubsysSnapshot Subsystem = "snapshot"
|
||||
SubsysTransport Subsystem = "transport"
|
||||
SubsysTransportMux Subsystem = "transportmux"
|
||||
SubsysRPC Subsystem = "rpc"
|
||||
SubsysRPCControl Subsystem = "rpc.ctrl"
|
||||
SubsysRPCData Subsystem = "rpc.data"
|
||||
)
|
||||
|
||||
func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Context {
|
||||
ctx = replication.WithLogger(ctx, log.WithField(SubsysField, "repl"))
|
||||
ctx = streamrpc.ContextWithLogger(ctx, streamrpcLogAdaptor{log.WithField(SubsysField, "rpc")})
|
||||
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint"))
|
||||
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning"))
|
||||
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot"))
|
||||
ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve"))
|
||||
ctx = replication.WithLogger(ctx, log.WithField(SubsysField, SubsysReplication))
|
||||
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, SubsyEndpoint))
|
||||
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, SubsysPruning))
|
||||
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, SubsysSnapshot))
|
||||
ctx = transport.WithLogger(ctx, log.WithField(SubsysField, SubsysTransport))
|
||||
ctx = transportmux.WithLogger(ctx, log.WithField(SubsysField, SubsysTransportMux))
|
||||
ctx = rpc.WithLoggers(ctx,
|
||||
rpc.Loggers{
|
||||
General: log.WithField(SubsysField, SubsysRPC),
|
||||
Control: log.WithField(SubsysField, SubsysRPCControl),
|
||||
Data: log.WithField(SubsysField, SubsysRPCData),
|
||||
},
|
||||
)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func LogSubsystem(log logger.Logger, subsys Subsystem) logger.Logger {
|
||||
return log.ReplaceField(SubsysField, subsys)
|
||||
}
|
||||
|
||||
func parseLogFormat(i interface{}) (f EntryFormatter, err error) {
|
||||
var is string
|
||||
switch j := i.(type) {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/job"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -49,6 +50,10 @@ func (j *prometheusJob) Run(ctx context.Context) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := frameconn.PrometheusRegister(prometheus.DefaultRegisterer); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
log := job.GetLogger(ctx)
|
||||
|
||||
l, err := net.Listen("tcp", j.listen)
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
"github.com/zrepl/zrepl/util/watchdog"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -19,14 +18,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Try to keep it compatible with gitub.com/zrepl/zrepl/replication.Endpoint
|
||||
// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint
|
||||
type History interface {
|
||||
ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
|
||||
}
|
||||
|
||||
// Try to keep it compatible with gitub.com/zrepl/zrepl/endpoint.Endpoint
|
||||
type Target interface {
|
||||
ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error)
|
||||
ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS
|
||||
ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error)
|
||||
ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error)
|
||||
DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error)
|
||||
}
|
||||
|
||||
@ -346,7 +346,6 @@ type Error interface {
|
||||
}
|
||||
|
||||
var _ Error = net.Error(nil)
|
||||
var _ Error = streamrpc.Error(nil)
|
||||
|
||||
func shouldRetry(e error) bool {
|
||||
if neterr, ok := e.(net.Error); ok {
|
||||
@ -381,10 +380,11 @@ func statePlan(a *args, u updater) state {
|
||||
ka = &pruner.Progress
|
||||
})
|
||||
|
||||
tfss, err := target.ListFilesystems(ctx)
|
||||
tfssres, err := target.ListFilesystems(ctx, &pdu.ListFilesystemReq{})
|
||||
if err != nil {
|
||||
return onErr(u, err)
|
||||
}
|
||||
tfss := tfssres.GetFilesystems()
|
||||
|
||||
pfss := make([]*fs, len(tfss))
|
||||
for i, tfs := range tfss {
|
||||
@ -398,11 +398,12 @@ func statePlan(a *args, u updater) state {
|
||||
}
|
||||
pfss[i] = pfs
|
||||
|
||||
tfsvs, err := target.ListFilesystemVersions(ctx, tfs.Path)
|
||||
tfsvsres, err := target.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: tfs.Path})
|
||||
if err != nil {
|
||||
l.WithError(err).Error("cannot list filesystem versions")
|
||||
return onErr(u, err)
|
||||
}
|
||||
tfsvs := tfsvsres.GetVersions()
|
||||
// no progress here since we could run in a live-lock (must have used target AND receiver before progress)
|
||||
|
||||
pfs.snaps = make([]pruning.Snapshot, 0, len(tfsvs))
|
||||
|
@ -44,7 +44,7 @@ type mockTarget struct {
|
||||
destroyErrs map[string][]error
|
||||
}
|
||||
|
||||
func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) {
|
||||
func (t *mockTarget) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
|
||||
if len(t.listFilesystemsErr) > 0 {
|
||||
e := t.listFilesystemsErr[0]
|
||||
t.listFilesystemsErr = t.listFilesystemsErr[1:]
|
||||
@ -54,10 +54,11 @@ func (t *mockTarget) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, er
|
||||
for i := range fss {
|
||||
fss[i] = t.fss[i].Filesystem()
|
||||
}
|
||||
return fss, nil
|
||||
return &pdu.ListFilesystemRes{Filesystems: fss}, nil
|
||||
}
|
||||
|
||||
func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) {
|
||||
func (t *mockTarget) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
|
||||
fs := req.Filesystem
|
||||
if len(t.listVersionsErrs[fs]) != 0 {
|
||||
e := t.listVersionsErrs[fs][0]
|
||||
t.listVersionsErrs[fs] = t.listVersionsErrs[fs][1:]
|
||||
@ -68,7 +69,7 @@ func (t *mockTarget) ListFilesystemVersions(ctx context.Context, fs string) ([]*
|
||||
if mfs.path != fs {
|
||||
continue
|
||||
}
|
||||
return mfs.FilesystemVersions(), nil
|
||||
return &pdu.ListFilesystemVersionsRes{Versions: mfs.FilesystemVersions()}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("filesystem %s does not exist", fs)
|
||||
}
|
||||
|
@ -177,7 +177,7 @@ func onMainCtxDone(ctx context.Context, u updater) state {
|
||||
}
|
||||
|
||||
func syncUp(a args, u updater) state {
|
||||
fss, err := listFSes(a.fsf)
|
||||
fss, err := listFSes(a.ctx, a.fsf)
|
||||
if err != nil {
|
||||
return onErr(err, u)
|
||||
}
|
||||
@ -204,7 +204,7 @@ func plan(a args, u updater) state {
|
||||
u(func(snapper *Snapper) {
|
||||
snapper.lastInvocation = time.Now()
|
||||
})
|
||||
fss, err := listFSes(a.fsf)
|
||||
fss, err := listFSes(a.ctx, a.fsf)
|
||||
if err != nil {
|
||||
return onErr(err, u)
|
||||
}
|
||||
@ -299,8 +299,8 @@ func wait(a args, u updater) state {
|
||||
}
|
||||
}
|
||||
|
||||
func listFSes(mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) {
|
||||
return zfs.ZFSListMapping(mf)
|
||||
func listFSes(ctx context.Context, mf *filters.DatasetMapFilter) (fss []*zfs.DatasetPath, err error) {
|
||||
return zfs.ZFSListMapping(ctx, mf)
|
||||
}
|
||||
|
||||
func findSyncPoint(log Logger, fss []*zfs.DatasetPath, prefix string, interval time.Duration) (syncPoint time.Time, err error) {
|
||||
|
@ -1,25 +0,0 @@
|
||||
package streamrpcconfig
|
||||
|
||||
import (
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
)
|
||||
|
||||
func FromDaemonConfig(g *config.Global, in *config.RPCConfig) (*streamrpc.ConnConfig, error) {
|
||||
conf := in
|
||||
if conf == nil {
|
||||
conf = g.RPC
|
||||
}
|
||||
srpcConf := &streamrpc.ConnConfig{
|
||||
RxHeaderMaxLen: conf.RxHeaderMaxLen,
|
||||
RxStructuredMaxLen: conf.RxStructuredMaxLen,
|
||||
RxStreamMaxChunkSize: conf.RxStreamChunkMaxLen,
|
||||
TxChunkSize: conf.TxChunkSize,
|
||||
Timeout: conf.Timeout,
|
||||
SendHeartbeatInterval: conf.SendHeartbeatInterval,
|
||||
}
|
||||
if err := srpcConf.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srpcConf, nil
|
||||
}
|
@ -1,84 +0,0 @@
|
||||
package connecter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
||||
"github.com/zrepl/zrepl/daemon/transport"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
type HandshakeConnecter struct {
|
||||
connecter streamrpc.Connecter
|
||||
}
|
||||
|
||||
func (c HandshakeConnecter) Connect(ctx context.Context) (net.Conn, error) {
|
||||
conn, err := c.connecter.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
dl = time.Now().Add(10 * time.Second) // FIXME constant
|
||||
}
|
||||
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
func FromConfig(g *config.Global, in config.ConnectEnum) (*ClientFactory, error) {
|
||||
var (
|
||||
connecter streamrpc.Connecter
|
||||
errConnecter, errRPC error
|
||||
connConf *streamrpc.ConnConfig
|
||||
)
|
||||
switch v := in.Ret.(type) {
|
||||
case *config.SSHStdinserverConnect:
|
||||
connecter, errConnecter = SSHStdinserverConnecterFromConfig(v)
|
||||
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.TCPConnect:
|
||||
connecter, errConnecter = TCPConnecterFromConfig(v)
|
||||
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.TLSConnect:
|
||||
connecter, errConnecter = TLSConnecterFromConfig(v)
|
||||
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.LocalConnect:
|
||||
connecter, errConnecter = LocalConnecterFromConfig(v)
|
||||
connConf, errRPC = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
default:
|
||||
panic(fmt.Sprintf("implementation error: unknown connecter type %T", v))
|
||||
}
|
||||
|
||||
if errConnecter != nil {
|
||||
return nil, errConnecter
|
||||
}
|
||||
if errRPC != nil {
|
||||
return nil, errRPC
|
||||
}
|
||||
|
||||
config := streamrpc.ClientConfig{ConnConfig: connConf}
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connecter = HandshakeConnecter{connecter}
|
||||
|
||||
return &ClientFactory{connecter: connecter, config: &config}, nil
|
||||
}
|
||||
|
||||
type ClientFactory struct {
|
||||
connecter streamrpc.Connecter
|
||||
config *streamrpc.ClientConfig
|
||||
}
|
||||
|
||||
func (f ClientFactory) NewClient() (*streamrpc.Client, error) {
|
||||
return streamrpc.NewClient(f.connecter, f.config)
|
||||
}
|
@ -1,136 +0,0 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type HandshakeMessage struct {
|
||||
ProtocolVersion int
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
func (m *HandshakeMessage) Encode() ([]byte, error) {
|
||||
if m.ProtocolVersion <= 0 || m.ProtocolVersion > 9999 {
|
||||
return nil, fmt.Errorf("protocol version must be in [1, 9999]")
|
||||
}
|
||||
if len(m.Extensions) >= 9999 {
|
||||
return nil, fmt.Errorf("protocol only supports [0, 9999] extensions")
|
||||
}
|
||||
// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
|
||||
var extensions strings.Builder
|
||||
for i, ext := range m.Extensions {
|
||||
if strings.ContainsAny(ext, "\n") {
|
||||
return nil, fmt.Errorf("Extension #%d contains forbidden newline character", i)
|
||||
}
|
||||
if !utf8.ValidString(ext) {
|
||||
return nil, fmt.Errorf("Extension #%d is not valid UTF-8", i)
|
||||
}
|
||||
extensions.WriteString(ext)
|
||||
extensions.WriteString("\n")
|
||||
}
|
||||
withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
|
||||
m.ProtocolVersion, len(m.Extensions), extensions.String())
|
||||
withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
|
||||
return []byte(withLen), nil
|
||||
}
|
||||
|
||||
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
|
||||
var lenAndSpace [11]byte
|
||||
if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if !utf8.Valid(lenAndSpace[:]) {
|
||||
return fmt.Errorf("invalid start of handshake message: not valid UTF-8")
|
||||
}
|
||||
var followLen int
|
||||
n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
|
||||
if n != 1 || err != nil {
|
||||
return fmt.Errorf("could not parse handshake message length")
|
||||
}
|
||||
if followLen > maxLen {
|
||||
return fmt.Errorf("handshake message length exceeds max length (%d vs %d)",
|
||||
followLen, maxLen)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = io.Copy(&buf, io.LimitReader(r, int64(followLen)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
protoVersion, extensionCount int
|
||||
)
|
||||
n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
|
||||
&protoVersion, &extensionCount)
|
||||
if n != 2 || err != nil {
|
||||
return fmt.Errorf("could not parse handshake message: %s", err)
|
||||
}
|
||||
if protoVersion < 1 {
|
||||
return fmt.Errorf("invalid protocol version %q", protoVersion)
|
||||
}
|
||||
m.ProtocolVersion = protoVersion
|
||||
|
||||
if extensionCount < 0 {
|
||||
return fmt.Errorf("invalid extension count %q", extensionCount)
|
||||
}
|
||||
if extensionCount == 0 {
|
||||
if buf.Len() != 0 {
|
||||
return fmt.Errorf("unexpected data trailing after header")
|
||||
}
|
||||
m.Extensions = nil
|
||||
return nil
|
||||
}
|
||||
s := buf.String()
|
||||
if strings.Count(s, "\n") != extensionCount {
|
||||
return fmt.Errorf("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount)
|
||||
}
|
||||
exts := strings.Split(s, "\n")
|
||||
if exts[len(exts)-1] != "" {
|
||||
return fmt.Errorf("unexpected data trailing after last extension newline")
|
||||
}
|
||||
m.Extensions = exts[0:len(exts)-1]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error {
|
||||
// current protocol version is hardcoded here
|
||||
return DoHandshakeVersion(conn, deadline, 1)
|
||||
}
|
||||
|
||||
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error {
|
||||
ours := HandshakeMessage{
|
||||
ProtocolVersion: version,
|
||||
Extensions: nil,
|
||||
}
|
||||
hsb, err := ours.Encode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not encode protocol banner: %s", err)
|
||||
}
|
||||
|
||||
conn.SetDeadline(deadline)
|
||||
_, err = io.Copy(conn, bytes.NewBuffer(hsb))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not send protocol banner: %s", err)
|
||||
}
|
||||
|
||||
theirs := HandshakeMessage{}
|
||||
if err := theirs.DecodeReader(conn, 16 * 4096); err != nil { // FIXME constant
|
||||
return fmt.Errorf("could not decode protocol banner: %s", err)
|
||||
}
|
||||
|
||||
if theirs.ProtocolVersion != ours.ProtocolVersion {
|
||||
return fmt.Errorf("protocol versions do not match: ours is %d, theirs is %d",
|
||||
ours.ProtocolVersion, theirs.ProtocolVersion)
|
||||
}
|
||||
// ignore extensions, we don't use them
|
||||
|
||||
return nil
|
||||
}
|
@ -1,147 +0,0 @@
|
||||
package serve
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/transport"
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"context"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"time"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const contextKeyLog contextKey = 0
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
func WithLogger(ctx context.Context, log Logger) context.Context {
|
||||
return context.WithValue(ctx, contextKeyLog, log)
|
||||
}
|
||||
|
||||
func getLogger(ctx context.Context) Logger {
|
||||
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
||||
return log
|
||||
}
|
||||
return logger.NewNullLogger()
|
||||
}
|
||||
|
||||
type AuthenticatedConn interface {
|
||||
net.Conn
|
||||
// ClientIdentity must be a string that satisfies ValidateClientIdentity
|
||||
ClientIdentity() string
|
||||
}
|
||||
|
||||
// A client identity must be a single component in a ZFS filesystem path
|
||||
func ValidateClientIdentity(in string) (err error) {
|
||||
path, err := zfs.NewDatasetPath(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if path.Length() != 1 {
|
||||
return errors.New("client identity must be a single path comonent (not empty, no '/')")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type authConn struct {
|
||||
net.Conn
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
var _ AuthenticatedConn = authConn{}
|
||||
|
||||
func (c authConn) ClientIdentity() string {
|
||||
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c.clientIdentity
|
||||
}
|
||||
|
||||
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
|
||||
type AuthenticatedListener interface {
|
||||
Addr() (net.Addr)
|
||||
Accept(ctx context.Context) (AuthenticatedConn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type ListenerFactory interface {
|
||||
Listen() (AuthenticatedListener, error)
|
||||
}
|
||||
|
||||
type HandshakeListenerFactory struct {
|
||||
lf ListenerFactory
|
||||
}
|
||||
|
||||
func (lf HandshakeListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
l, err := lf.lf.Listen()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return HandshakeListener{l}, nil
|
||||
}
|
||||
|
||||
type HandshakeListener struct {
|
||||
l AuthenticatedListener
|
||||
}
|
||||
|
||||
func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() }
|
||||
|
||||
func (l HandshakeListener) Close() error { return l.l.Close() }
|
||||
|
||||
func (l HandshakeListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
conn, err := l.l.Accept(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
dl = time.Now().Add(10*time.Second) // FIXME constant
|
||||
}
|
||||
if err := transport.DoHandshakeCurrentVersion(conn, dl); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
||||
|
||||
var (
|
||||
lfError, rpcErr error
|
||||
)
|
||||
switch v := in.Ret.(type) {
|
||||
case *config.TCPServe:
|
||||
lf, lfError = TCPListenerFactoryFromConfig(g, v)
|
||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.TLSServe:
|
||||
lf, lfError = TLSListenerFactoryFromConfig(g, v)
|
||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.StdinserverServer:
|
||||
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
|
||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.LocalServe:
|
||||
lf, lfError = LocalListenerFactoryFromConfig(g, v)
|
||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
default:
|
||||
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
|
||||
}
|
||||
|
||||
if lfError != nil {
|
||||
return nil, nil, lfError
|
||||
}
|
||||
if rpcErr != nil {
|
||||
return nil, nil, rpcErr
|
||||
}
|
||||
|
||||
lf = HandshakeListenerFactory{lf}
|
||||
|
||||
return lf, conf, nil
|
||||
|
||||
}
|
||||
|
||||
|
@ -1,83 +0,0 @@
|
||||
package serve
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/tlsconf"
|
||||
"net"
|
||||
"time"
|
||||
"context"
|
||||
)
|
||||
|
||||
type TLSListenerFactory struct {
|
||||
address string
|
||||
clientCA *x509.CertPool
|
||||
serverCert tls.Certificate
|
||||
handshakeTimeout time.Duration
|
||||
clientCNs map[string]struct{}
|
||||
}
|
||||
|
||||
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TLSListenerFactory, err error) {
|
||||
lf = &TLSListenerFactory{
|
||||
address: in.Listen,
|
||||
handshakeTimeout: in.HandshakeTimeout,
|
||||
}
|
||||
|
||||
if in.Ca == "" || in.Cert == "" || in.Key == "" {
|
||||
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
|
||||
}
|
||||
|
||||
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse ca file")
|
||||
}
|
||||
|
||||
lf.serverCert, err = tls.LoadX509KeyPair(in.Cert, in.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse cer/key pair")
|
||||
}
|
||||
|
||||
lf.clientCNs = make(map[string]struct{}, len(in.ClientCNs))
|
||||
for i, cn := range in.ClientCNs {
|
||||
if err := ValidateClientIdentity(cn); err != nil {
|
||||
return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn)
|
||||
}
|
||||
// dupes are ok fr now
|
||||
lf.clientCNs[cn] = struct{}{}
|
||||
}
|
||||
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
l, err := net.Listen("tcp", f.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout)
|
||||
return tlsAuthListener{tl, f.clientCNs}, nil
|
||||
}
|
||||
|
||||
type tlsAuthListener struct {
|
||||
*tlsconf.ClientAuthListener
|
||||
clientCNs map[string]struct{}
|
||||
}
|
||||
|
||||
func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
c, cn, err := l.ClientAuthListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := l.clientCNs[cn]; !ok {
|
||||
if err := c.Close(); err != nil {
|
||||
getLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name")
|
||||
}
|
||||
return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, c.RemoteAddr())
|
||||
}
|
||||
return authConn{c, cn}, nil
|
||||
}
|
||||
|
||||
|
@ -203,13 +203,17 @@ The serve & connect configuration will thus look like the following:
|
||||
``ssh+stdinserver`` Transport
|
||||
-----------------------------
|
||||
|
||||
``ssh+stdinserver`` is inspired by `git shell <https://git-scm.com/docs/git-shell>`_ and `Borg Backup <https://borgbackup.readthedocs.io/en/stable/deployment.html>`_.
|
||||
It is provided by the Go package ``github.com/problame/go-netssh``.
|
||||
``ssh+stdinserver`` uses the ``ssh`` command and some features of the server-side SSH ``authorized_keys`` file.
|
||||
It is less efficient than other transports because the data passes through two more pipes.
|
||||
However, it is fairly convenient to set up and allows the zrepl daemon to not be directly exposed to the internet, because all traffic passes through the system's SSH server.
|
||||
|
||||
.. ATTENTION::
|
||||
The concept is inspired by `git shell <https://git-scm.com/docs/git-shell>`_ and `Borg Backup <https://borgbackup.readthedocs.io/en/stable/deployment.html>`_.
|
||||
The implementation is provided by the Go package ``github.com/problame/go-netssh``.
|
||||
|
||||
``ssh+stdinserver`` has inferior error detection and handling compared to the ``tcp`` and ``tls`` transports.
|
||||
If you require tested timeout & retry handling, use ``tcp`` or ``tls`` transports, or help improve package go-netssh.
|
||||
.. NOTE::
|
||||
|
||||
``ssh+stdinserver`` generally provides inferior error detection and handling compared to the ``tcp`` and ``tls`` transports.
|
||||
When encountering such problems, consider using ``tcp`` or ``tls`` transports, or help improve package go-netssh.
|
||||
|
||||
.. _transport-ssh+stdinserver-serve:
|
||||
|
||||
|
@ -9,6 +9,7 @@ type contextKey int
|
||||
|
||||
const (
|
||||
contextKeyLogger contextKey = iota
|
||||
ClientIdentityKey
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
@ -2,16 +2,14 @@
|
||||
package endpoint
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"path"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/replication"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Sender implements replication.ReplicationEndpoint for a sending side
|
||||
@ -41,8 +39,8 @@ func (s *Sender) filterCheckFS(fs string) (*zfs.DatasetPath, error) {
|
||||
return dp, nil
|
||||
}
|
||||
|
||||
func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) {
|
||||
fss, err := zfs.ZFSListMapping(p.FSFilter)
|
||||
func (s *Sender) ListFilesystems(ctx context.Context, r *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
|
||||
fss, err := zfs.ZFSListMapping(ctx, s.FSFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -53,11 +51,12 @@ func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error)
|
||||
// FIXME: not supporting ResumeToken yet
|
||||
}
|
||||
}
|
||||
return rfss, nil
|
||||
res := &pdu.ListFilesystemRes{Filesystems: rfss, Empty: len(rfss) == 0}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) {
|
||||
lp, err := p.filterCheckFS(fs)
|
||||
func (s *Sender) ListFilesystemVersions(ctx context.Context, r *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
|
||||
lp, err := s.filterCheckFS(r.GetFilesystem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -69,16 +68,17 @@ func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.
|
||||
for i := range fsvs {
|
||||
rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i])
|
||||
}
|
||||
return rfsvs, nil
|
||||
res := &pdu.ListFilesystemVersionsRes{Versions: rfsvs}
|
||||
return res, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
_, err := p.filterCheckFS(r.Filesystem)
|
||||
func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
_, err := s.filterCheckFS(r.Filesystem)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if r.DryRun {
|
||||
si, err := zfs.ZFSSendDry(r.Filesystem, r.From, r.To, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@ -87,14 +87,17 @@ func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.Rea
|
||||
if si.SizeEstimate != -1 { // but si returns -1 for no size estimate
|
||||
expSize = si.SizeEstimate
|
||||
}
|
||||
return &pdu.SendRes{ExpectedSize: expSize}, nil, nil
|
||||
} else {
|
||||
stream, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "")
|
||||
res := &pdu.SendRes{ExpectedSize: expSize}
|
||||
|
||||
if r.DryRun {
|
||||
return res, nil, nil
|
||||
}
|
||||
|
||||
streamCopier, err := zfs.ZFSSend(ctx, r.Filesystem, r.From, r.To, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &pdu.SendRes{}, stream, nil
|
||||
}
|
||||
return res, streamCopier, nil
|
||||
}
|
||||
|
||||
func (p *Sender) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
|
||||
@ -132,6 +135,10 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
return nil, fmt.Errorf("sender does not implement Receive()")
|
||||
}
|
||||
|
||||
type FSFilter interface { // FIXME unused
|
||||
Filter(path *zfs.DatasetPath) (pass bool, err error)
|
||||
}
|
||||
@ -146,14 +153,50 @@ type FSMap interface { // FIXME unused
|
||||
|
||||
// Receiver implements replication.ReplicationEndpoint for a receiving side
|
||||
type Receiver struct {
|
||||
root *zfs.DatasetPath
|
||||
rootWithoutClientComponent *zfs.DatasetPath
|
||||
appendClientIdentity bool
|
||||
}
|
||||
|
||||
func NewReceiver(rootDataset *zfs.DatasetPath) (*Receiver, error) {
|
||||
func NewReceiver(rootDataset *zfs.DatasetPath, appendClientIdentity bool) *Receiver {
|
||||
if rootDataset.Length() <= 0 {
|
||||
return nil, errors.New("root dataset must not be an empty path")
|
||||
panic(fmt.Sprintf("root dataset must not be an empty path: %v", rootDataset))
|
||||
}
|
||||
return &Receiver{root: rootDataset.Copy()}, nil
|
||||
return &Receiver{rootWithoutClientComponent: rootDataset.Copy(), appendClientIdentity: appendClientIdentity}
|
||||
}
|
||||
|
||||
func TestClientIdentity(rootFS *zfs.DatasetPath, clientIdentity string) error {
|
||||
_, err := clientRoot(rootFS, clientIdentity)
|
||||
return err
|
||||
}
|
||||
|
||||
func clientRoot(rootFS *zfs.DatasetPath, clientIdentity string) (*zfs.DatasetPath, error) {
|
||||
rootFSLen := rootFS.Length()
|
||||
clientRootStr := path.Join(rootFS.ToString(), clientIdentity)
|
||||
clientRoot, err := zfs.NewDatasetPath(clientRootStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rootFSLen+1 != clientRoot.Length() {
|
||||
return nil, fmt.Errorf("client identity must be a single ZFS filesystem path component")
|
||||
}
|
||||
return clientRoot, nil
|
||||
}
|
||||
|
||||
func (s *Receiver) clientRootFromCtx(ctx context.Context) *zfs.DatasetPath {
|
||||
if !s.appendClientIdentity {
|
||||
return s.rootWithoutClientComponent.Copy()
|
||||
}
|
||||
|
||||
clientIdentity, ok := ctx.Value(ClientIdentityKey).(string)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("ClientIdentityKey context value must be set"))
|
||||
}
|
||||
|
||||
clientRoot, err := clientRoot(s.rootWithoutClientComponent, clientIdentity)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ClientIdentityContextKey must have been validated before invoking Receiver: %s", err))
|
||||
}
|
||||
return clientRoot
|
||||
}
|
||||
|
||||
type subroot struct {
|
||||
@ -180,8 +223,9 @@ func (f subroot) MapToLocal(fs string) (*zfs.DatasetPath, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) {
|
||||
filtered, err := zfs.ZFSListMapping(subroot{e.root})
|
||||
func (s *Receiver) ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
|
||||
root := s.clientRootFromCtx(ctx)
|
||||
filtered, err := zfs.ZFSListMapping(ctx, subroot{root})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -194,19 +238,30 @@ func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, erro
|
||||
WithError(err).
|
||||
WithField("fs", a).
|
||||
Error("inconsistent placeholder property")
|
||||
return nil, errors.New("server error, see logs") // don't leak path
|
||||
return nil, errors.New("server error: inconsistent placeholder property") // don't leak path
|
||||
}
|
||||
if ph {
|
||||
getLogger(ctx).
|
||||
WithField("fs", a.ToString()).
|
||||
Debug("ignoring placeholder filesystem")
|
||||
continue
|
||||
}
|
||||
a.TrimPrefix(e.root)
|
||||
getLogger(ctx).
|
||||
WithField("fs", a.ToString()).
|
||||
Debug("non-placeholder filesystem")
|
||||
a.TrimPrefix(root)
|
||||
fss = append(fss, &pdu.Filesystem{Path: a.ToString()})
|
||||
}
|
||||
return fss, nil
|
||||
if len(fss) == 0 {
|
||||
getLogger(ctx).Debug("no non-placeholder filesystems")
|
||||
return &pdu.ListFilesystemRes{Empty: true}, nil
|
||||
}
|
||||
return &pdu.ListFilesystemRes{Filesystems: fss}, nil
|
||||
}
|
||||
|
||||
func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) {
|
||||
lp, err := subroot{e.root}.MapToLocal(fs)
|
||||
func (s *Receiver) ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
|
||||
root := s.clientRootFromCtx(ctx)
|
||||
lp, err := subroot{root}.MapToLocal(req.GetFilesystem())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -221,18 +276,26 @@ func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pd
|
||||
rfsvs[i] = pdu.FilesystemVersionFromZFS(&fsvs[i])
|
||||
}
|
||||
|
||||
return rfsvs, nil
|
||||
return &pdu.ListFilesystemVersionsRes{Versions: rfsvs}, nil
|
||||
}
|
||||
|
||||
func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream io.ReadCloser) error {
|
||||
defer sendStream.Close()
|
||||
|
||||
lp, err := subroot{e.root}.MapToLocal(req.Filesystem)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) {
|
||||
return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver")
|
||||
}
|
||||
|
||||
func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
return nil, nil, fmt.Errorf("receiver does not implement Send()")
|
||||
}
|
||||
|
||||
func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
getLogger(ctx).Debug("incoming Receive")
|
||||
defer receive.Close()
|
||||
|
||||
root := s.clientRootFromCtx(ctx)
|
||||
lp, err := subroot{root}.MapToLocal(req.Filesystem)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create placeholder parent filesystems as appropriate
|
||||
var visitErr error
|
||||
@ -261,7 +324,7 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream
|
||||
getLogger(ctx).WithField("visitErr", visitErr).Debug("complete tree-walk")
|
||||
|
||||
if visitErr != nil {
|
||||
return visitErr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
needForceRecv := false
|
||||
@ -279,19 +342,19 @@ func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream
|
||||
|
||||
getLogger(ctx).Debug("start receive command")
|
||||
|
||||
if err := zfs.ZFSRecv(ctx, lp.ToString(), sendStream, args...); err != nil {
|
||||
if err := zfs.ZFSRecv(ctx, lp.ToString(), receive, args...); err != nil {
|
||||
getLogger(ctx).
|
||||
WithError(err).
|
||||
WithField("args", args).
|
||||
Error("zfs receive failed")
|
||||
sendStream.Close()
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
return &pdu.ReceiveRes{}, nil
|
||||
}
|
||||
|
||||
func (e *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
|
||||
lp, err := subroot{e.root}.MapToLocal(req.Filesystem)
|
||||
func (s *Receiver) DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
|
||||
root := s.clientRootFromCtx(ctx)
|
||||
lp, err := subroot{root}.MapToLocal(req.Filesystem)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -326,289 +389,3 @@ func doDestroySnapshots(ctx context.Context, lp *zfs.DatasetPath, snaps []*pdu.F
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
// RPC STUBS
|
||||
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||
|
||||
const (
|
||||
RPCListFilesystems = "ListFilesystems"
|
||||
RPCListFilesystemVersions = "ListFilesystemVersions"
|
||||
RPCReceive = "Receive"
|
||||
RPCSend = "Send"
|
||||
RPCSDestroySnapshots = "DestroySnapshots"
|
||||
RPCReplicationCursor = "ReplicationCursor"
|
||||
)
|
||||
|
||||
// Remote implements an endpoint stub that uses streamrpc as a transport.
|
||||
type Remote struct {
|
||||
c *streamrpc.Client
|
||||
}
|
||||
|
||||
func NewRemote(c *streamrpc.Client) Remote {
|
||||
return Remote{c}
|
||||
}
|
||||
|
||||
func (s Remote) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) {
|
||||
req := pdu.ListFilesystemReq{}
|
||||
b, err := proto.Marshal(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rs != nil {
|
||||
rs.Close()
|
||||
return nil, errors.New("response contains unexpected stream")
|
||||
}
|
||||
var res pdu.ListFilesystemRes
|
||||
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Filesystems, nil
|
||||
}
|
||||
|
||||
func (s Remote) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) {
|
||||
req := pdu.ListFilesystemVersionsReq{
|
||||
Filesystem: fs,
|
||||
}
|
||||
b, err := proto.Marshal(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rs != nil {
|
||||
rs.Close()
|
||||
return nil, errors.New("response contains unexpected stream")
|
||||
}
|
||||
var res pdu.ListFilesystemVersionsRes
|
||||
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Versions, nil
|
||||
}
|
||||
|
||||
func (s Remote) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
b, err := proto.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
rb, rs, err := s.c.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !r.DryRun && rs == nil {
|
||||
return nil, nil, errors.New("response does not contain a stream")
|
||||
}
|
||||
if r.DryRun && rs != nil {
|
||||
rs.Close()
|
||||
return nil, nil, errors.New("response contains unexpected stream (was dry run)")
|
||||
}
|
||||
var res pdu.SendRes
|
||||
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
|
||||
rs.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return &res, rs, nil
|
||||
}
|
||||
|
||||
func (s Remote) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error {
|
||||
defer sendStream.Close()
|
||||
b, err := proto.Marshal(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rb, rs, err := s.c.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream)
|
||||
getLogger(ctx).WithField("err", err).Debug("Remote.Receive RequestReplyReturned")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rs != nil {
|
||||
rs.Close()
|
||||
return errors.New("response contains unexpected stream")
|
||||
}
|
||||
var res pdu.ReceiveRes
|
||||
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Remote) DestroySnapshots(ctx context.Context, r *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
|
||||
b, err := proto.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rb, rs, err := s.c.RequestReply(ctx, RPCSDestroySnapshots, bytes.NewBuffer(b), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rs != nil {
|
||||
rs.Close()
|
||||
return nil, errors.New("response contains unexpected stream")
|
||||
}
|
||||
var res pdu.DestroySnapshotsRes
|
||||
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (s Remote) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) {
|
||||
b, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rb, rs, err := s.c.RequestReply(ctx, RPCReplicationCursor, bytes.NewBuffer(b), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rs != nil {
|
||||
rs.Close()
|
||||
return nil, errors.New("response contains unexpected stream")
|
||||
}
|
||||
var res pdu.ReplicationCursorRes
|
||||
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// Handler implements the server-side streamrpc.HandlerFunc for a Remote endpoint stub.
|
||||
type Handler struct {
|
||||
ep replication.Endpoint
|
||||
}
|
||||
|
||||
func NewHandler(ep replication.Endpoint) Handler {
|
||||
return Handler{ep}
|
||||
}
|
||||
|
||||
func (a *Handler) Handle(ctx context.Context, endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) {
|
||||
|
||||
switch endpoint {
|
||||
case RPCListFilesystems:
|
||||
var req pdu.ListFilesystemReq
|
||||
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
fsses, err := a.ep.ListFilesystems(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
res := &pdu.ListFilesystemRes{
|
||||
Filesystems: fsses,
|
||||
}
|
||||
b, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bytes.NewBuffer(b), nil, nil
|
||||
|
||||
case RPCListFilesystemVersions:
|
||||
|
||||
var req pdu.ListFilesystemVersionsReq
|
||||
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
res := &pdu.ListFilesystemVersionsRes{
|
||||
Versions: fsvs,
|
||||
}
|
||||
b, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bytes.NewBuffer(b), nil, nil
|
||||
|
||||
case RPCSend:
|
||||
|
||||
sender, ok := a.ep.(replication.Sender)
|
||||
if !ok {
|
||||
goto Err
|
||||
}
|
||||
|
||||
var req pdu.SendReq
|
||||
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
res, sendStream, err := sender.Send(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
b, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bytes.NewBuffer(b), sendStream, err
|
||||
|
||||
case RPCReceive:
|
||||
|
||||
receiver, ok := a.ep.(replication.Receiver)
|
||||
if !ok {
|
||||
goto Err
|
||||
}
|
||||
|
||||
var req pdu.ReceiveReq
|
||||
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
err := receiver.Receive(ctx, &req, reqStream)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
b, err := proto.Marshal(&pdu.ReceiveRes{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bytes.NewBuffer(b), nil, err
|
||||
|
||||
case RPCSDestroySnapshots:
|
||||
|
||||
var req pdu.DestroySnapshotsReq
|
||||
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
res, err := a.ep.DestroySnapshots(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
b, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bytes.NewBuffer(b), nil, nil
|
||||
|
||||
case RPCReplicationCursor:
|
||||
|
||||
sender, ok := a.ep.(replication.Sender)
|
||||
if !ok {
|
||||
goto Err
|
||||
}
|
||||
|
||||
var req pdu.ReplicationCursorReq
|
||||
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
res, err := sender.ReplicationCursor(ctx, &req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
b, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return bytes.NewBuffer(b), nil, nil
|
||||
|
||||
}
|
||||
Err:
|
||||
return nil, nil, errors.New("no handler for given endpoint")
|
||||
}
|
||||
|
27
logger/stderrlogger.go
Normal file
27
logger/stderrlogger.go
Normal file
@ -0,0 +1,27 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
type stderrLogger struct {
|
||||
Logger
|
||||
}
|
||||
|
||||
type stderrLoggerOutlet struct {}
|
||||
|
||||
func (stderrLoggerOutlet) WriteEntry(entry Entry) error {
|
||||
fmt.Fprintf(os.Stderr, "%#v\n", entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Logger = testLogger{}
|
||||
|
||||
func NewStderrDebugLogger() Logger {
|
||||
outlets := NewOutlets()
|
||||
outlets.Add(&stderrLoggerOutlet{}, Debug)
|
||||
return &testLogger{
|
||||
Logger: NewLogger(outlets, 0),
|
||||
}
|
||||
}
|
@ -6,16 +6,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/zrepl/zrepl/util/watchdog"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/util"
|
||||
"github.com/zrepl/zrepl/util/bytecounter"
|
||||
"github.com/zrepl/zrepl/util/watchdog"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
@ -43,7 +43,7 @@ type Sender interface {
|
||||
// If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before
|
||||
// any next call to the parent github.com/zrepl/zrepl/replication.Endpoint.
|
||||
// If the send request is for dry run the io.ReadCloser will be nil
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
|
||||
ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
|
||||
}
|
||||
|
||||
@ -51,9 +51,7 @@ type Sender interface {
|
||||
type Receiver interface {
|
||||
// Receive sends r and sendStream (the latter containing a ZFS send stream)
|
||||
// to the parent github.com/zrepl/zrepl/replication.Endpoint.
|
||||
// Implementors must guarantee that Close was called on sendStream before
|
||||
// the call to Receive returns.
|
||||
Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) error
|
||||
Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
|
||||
}
|
||||
|
||||
type StepReport struct {
|
||||
@ -227,7 +225,7 @@ type ReplicationStep struct {
|
||||
// both retry and permanent error
|
||||
err error
|
||||
|
||||
byteCounter *util.ByteCounterReader
|
||||
byteCounter bytecounter.StreamCopier
|
||||
expectedSize int64 // 0 means no size estimate present / possible
|
||||
}
|
||||
|
||||
@ -401,37 +399,54 @@ func (s *ReplicationStep) doReplication(ctx context.Context, ka *watchdog.KeepAl
|
||||
sr := s.buildSendRequest(false)
|
||||
|
||||
log.Debug("initiate send request")
|
||||
sres, sstream, err := sender.Send(ctx, sr)
|
||||
sres, sstreamCopier, err := sender.Send(ctx, sr)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("send request failed")
|
||||
return err
|
||||
}
|
||||
if sstream == nil {
|
||||
if sstreamCopier == nil {
|
||||
err := errors.New("send request did not return a stream, broken endpoint implementation")
|
||||
return err
|
||||
}
|
||||
defer sstreamCopier.Close()
|
||||
|
||||
s.byteCounter = util.NewByteCounterReader(sstream)
|
||||
s.byteCounter.SetCallback(1*time.Second, func(i int64) {
|
||||
// Install a byte counter to track progress + for status report
|
||||
s.byteCounter = bytecounter.NewStreamCopier(sstreamCopier)
|
||||
byteCounterStopProgress := make(chan struct{})
|
||||
defer close(byteCounterStopProgress)
|
||||
go func() {
|
||||
var lastCount int64
|
||||
t := time.NewTicker(1 * time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-byteCounterStopProgress:
|
||||
return
|
||||
case <-t.C:
|
||||
newCount := s.byteCounter.Count()
|
||||
if lastCount != newCount {
|
||||
ka.MadeProgress()
|
||||
})
|
||||
defer func() {
|
||||
s.parent.promBytesReplicated.Add(float64(s.byteCounter.Bytes()))
|
||||
} else {
|
||||
lastCount = newCount
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
s.parent.promBytesReplicated.Add(float64(s.byteCounter.Count()))
|
||||
}()
|
||||
sstream = s.byteCounter
|
||||
|
||||
rr := &pdu.ReceiveReq{
|
||||
Filesystem: fs,
|
||||
ClearResumeToken: !sres.UsedResumeToken,
|
||||
}
|
||||
log.Debug("initiate receive request")
|
||||
err = receiver.Receive(ctx, rr, sstream)
|
||||
_, err = receiver.Receive(ctx, rr, s.byteCounter)
|
||||
if err != nil {
|
||||
log.
|
||||
WithError(err).
|
||||
WithField("errType", fmt.Sprintf("%T", err)).
|
||||
Error("receive request failed (might also be error on sender)")
|
||||
sstream.Close()
|
||||
// This failure could be due to
|
||||
// - an unexpected exit of ZFS on the sending side
|
||||
// - an unexpected exit of ZFS on the receiving side
|
||||
@ -524,7 +539,7 @@ func (s *ReplicationStep) Report() *StepReport {
|
||||
}
|
||||
bytes := int64(0)
|
||||
if s.byteCounter != nil {
|
||||
bytes = s.byteCounter.Bytes()
|
||||
bytes = s.byteCounter.Count()
|
||||
}
|
||||
problem := ""
|
||||
if s.err != nil {
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"github.com/zrepl/zrepl/daemon/job/wakeup"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
"github.com/zrepl/zrepl/util/watchdog"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"math/bits"
|
||||
"net"
|
||||
"sort"
|
||||
@ -106,9 +105,8 @@ func NewReplication(secsPerState *prometheus.HistogramVec, bytesReplicated *prom
|
||||
// named interfaces defined in this package.
|
||||
type Endpoint interface {
|
||||
// Does not include placeholder filesystems
|
||||
ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error)
|
||||
// FIXME document FilteredError handling
|
||||
ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) // fix depS
|
||||
ListFilesystems(ctx context.Context, req *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error)
|
||||
ListFilesystemVersions(ctx context.Context, req *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error)
|
||||
DestroySnapshots(ctx context.Context, req *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error)
|
||||
}
|
||||
|
||||
@ -203,7 +201,6 @@ type Error interface {
|
||||
|
||||
var _ Error = fsrep.Error(nil)
|
||||
var _ Error = net.Error(nil)
|
||||
var _ Error = streamrpc.Error(nil)
|
||||
|
||||
func isPermanent(err error) bool {
|
||||
if e, ok := err.(Error); ok {
|
||||
@ -232,19 +229,20 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
|
||||
}).rsf()
|
||||
}
|
||||
|
||||
sfss, err := sender.ListFilesystems(ctx)
|
||||
slfssres, err := sender.ListFilesystems(ctx, &pdu.ListFilesystemReq{})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error listing sender filesystems")
|
||||
log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing sender filesystems")
|
||||
return handlePlanningError(err)
|
||||
}
|
||||
sfss := slfssres.GetFilesystems()
|
||||
// no progress here since we could run in a live-lock on connectivity issues
|
||||
|
||||
rfss, err := receiver.ListFilesystems(ctx)
|
||||
rlfssres, err := receiver.ListFilesystems(ctx, &pdu.ListFilesystemReq{})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error listing receiver filesystems")
|
||||
log.WithError(err).WithField("errType", fmt.Sprintf("%T", err)).Error("error listing receiver filesystems")
|
||||
return handlePlanningError(err)
|
||||
}
|
||||
|
||||
rfss := rlfssres.GetFilesystems()
|
||||
ka.MadeProgress() // for both sender and receiver
|
||||
|
||||
q := make([]*fsrep.Replication, 0, len(sfss))
|
||||
@ -255,11 +253,12 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
|
||||
|
||||
log.Debug("assessing filesystem")
|
||||
|
||||
sfsvs, err := sender.ListFilesystemVersions(ctx, fs.Path)
|
||||
sfsvsres, err := sender.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path})
|
||||
if err != nil {
|
||||
log.WithError(err).Error("cannot get remote filesystem versions")
|
||||
return handlePlanningError(err)
|
||||
}
|
||||
sfsvs := sfsvsres.GetVersions()
|
||||
ka.MadeProgress()
|
||||
|
||||
if len(sfsvs) < 1 {
|
||||
@ -278,7 +277,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
|
||||
|
||||
var rfsvs []*pdu.FilesystemVersion
|
||||
if receiverFSExists {
|
||||
rfsvs, err = receiver.ListFilesystemVersions(ctx, fs.Path)
|
||||
rfsvsres, err := receiver.ListFilesystemVersions(ctx, &pdu.ListFilesystemVersionsReq{Filesystem: fs.Path})
|
||||
if err != nil {
|
||||
if _, ok := err.(*FilteredError); ok {
|
||||
log.Info("receiver ignores filesystem")
|
||||
@ -287,6 +286,7 @@ func statePlanning(ctx context.Context, ka *watchdog.KeepAlive, sender Sender, r
|
||||
log.WithError(err).Error("receiver error")
|
||||
return handlePlanningError(err)
|
||||
}
|
||||
rfsvs = rfsvsres.GetVersions()
|
||||
} else {
|
||||
rfsvs = []*pdu.FilesystemVersion{}
|
||||
}
|
||||
|
@ -7,6 +7,11 @@ import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
@ -38,7 +43,7 @@ func (x FilesystemVersion_VersionType) String() string {
|
||||
return proto.EnumName(FilesystemVersion_VersionType_name, int32(x))
|
||||
}
|
||||
func (FilesystemVersion_VersionType) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5, 0}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{5, 0}
|
||||
}
|
||||
|
||||
type ListFilesystemReq struct {
|
||||
@ -51,7 +56,7 @@ func (m *ListFilesystemReq) Reset() { *m = ListFilesystemReq{} }
|
||||
func (m *ListFilesystemReq) String() string { return proto.CompactTextString(m) }
|
||||
func (*ListFilesystemReq) ProtoMessage() {}
|
||||
func (*ListFilesystemReq) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{0}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{0}
|
||||
}
|
||||
func (m *ListFilesystemReq) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ListFilesystemReq.Unmarshal(m, b)
|
||||
@ -73,6 +78,7 @@ var xxx_messageInfo_ListFilesystemReq proto.InternalMessageInfo
|
||||
|
||||
type ListFilesystemRes struct {
|
||||
Filesystems []*Filesystem `protobuf:"bytes,1,rep,name=Filesystems,proto3" json:"Filesystems,omitempty"`
|
||||
Empty bool `protobuf:"varint,2,opt,name=Empty,proto3" json:"Empty,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
@ -82,7 +88,7 @@ func (m *ListFilesystemRes) Reset() { *m = ListFilesystemRes{} }
|
||||
func (m *ListFilesystemRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*ListFilesystemRes) ProtoMessage() {}
|
||||
func (*ListFilesystemRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{1}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{1}
|
||||
}
|
||||
func (m *ListFilesystemRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ListFilesystemRes.Unmarshal(m, b)
|
||||
@ -109,6 +115,13 @@ func (m *ListFilesystemRes) GetFilesystems() []*Filesystem {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ListFilesystemRes) GetEmpty() bool {
|
||||
if m != nil {
|
||||
return m.Empty
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type Filesystem struct {
|
||||
Path string `protobuf:"bytes,1,opt,name=Path,proto3" json:"Path,omitempty"`
|
||||
ResumeToken string `protobuf:"bytes,2,opt,name=ResumeToken,proto3" json:"ResumeToken,omitempty"`
|
||||
@ -121,7 +134,7 @@ func (m *Filesystem) Reset() { *m = Filesystem{} }
|
||||
func (m *Filesystem) String() string { return proto.CompactTextString(m) }
|
||||
func (*Filesystem) ProtoMessage() {}
|
||||
func (*Filesystem) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{2}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{2}
|
||||
}
|
||||
func (m *Filesystem) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Filesystem.Unmarshal(m, b)
|
||||
@ -166,7 +179,7 @@ func (m *ListFilesystemVersionsReq) Reset() { *m = ListFilesystemVersion
|
||||
func (m *ListFilesystemVersionsReq) String() string { return proto.CompactTextString(m) }
|
||||
func (*ListFilesystemVersionsReq) ProtoMessage() {}
|
||||
func (*ListFilesystemVersionsReq) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{3}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{3}
|
||||
}
|
||||
func (m *ListFilesystemVersionsReq) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ListFilesystemVersionsReq.Unmarshal(m, b)
|
||||
@ -204,7 +217,7 @@ func (m *ListFilesystemVersionsRes) Reset() { *m = ListFilesystemVersion
|
||||
func (m *ListFilesystemVersionsRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*ListFilesystemVersionsRes) ProtoMessage() {}
|
||||
func (*ListFilesystemVersionsRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{4}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{4}
|
||||
}
|
||||
func (m *ListFilesystemVersionsRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ListFilesystemVersionsRes.Unmarshal(m, b)
|
||||
@ -232,7 +245,7 @@ func (m *ListFilesystemVersionsRes) GetVersions() []*FilesystemVersion {
|
||||
}
|
||||
|
||||
type FilesystemVersion struct {
|
||||
Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=pdu.FilesystemVersion_VersionType" json:"Type,omitempty"`
|
||||
Type FilesystemVersion_VersionType `protobuf:"varint,1,opt,name=Type,proto3,enum=FilesystemVersion_VersionType" json:"Type,omitempty"`
|
||||
Name string `protobuf:"bytes,2,opt,name=Name,proto3" json:"Name,omitempty"`
|
||||
Guid uint64 `protobuf:"varint,3,opt,name=Guid,proto3" json:"Guid,omitempty"`
|
||||
CreateTXG uint64 `protobuf:"varint,4,opt,name=CreateTXG,proto3" json:"CreateTXG,omitempty"`
|
||||
@ -246,7 +259,7 @@ func (m *FilesystemVersion) Reset() { *m = FilesystemVersion{} }
|
||||
func (m *FilesystemVersion) String() string { return proto.CompactTextString(m) }
|
||||
func (*FilesystemVersion) ProtoMessage() {}
|
||||
func (*FilesystemVersion) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{5}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{5}
|
||||
}
|
||||
func (m *FilesystemVersion) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_FilesystemVersion.Unmarshal(m, b)
|
||||
@ -326,7 +339,7 @@ func (m *SendReq) Reset() { *m = SendReq{} }
|
||||
func (m *SendReq) String() string { return proto.CompactTextString(m) }
|
||||
func (*SendReq) ProtoMessage() {}
|
||||
func (*SendReq) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{6}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{6}
|
||||
}
|
||||
func (m *SendReq) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_SendReq.Unmarshal(m, b)
|
||||
@ -407,7 +420,7 @@ func (m *Property) Reset() { *m = Property{} }
|
||||
func (m *Property) String() string { return proto.CompactTextString(m) }
|
||||
func (*Property) ProtoMessage() {}
|
||||
func (*Property) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{7}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{7}
|
||||
}
|
||||
func (m *Property) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Property.Unmarshal(m, b)
|
||||
@ -443,11 +456,11 @@ func (m *Property) GetValue() string {
|
||||
|
||||
type SendRes struct {
|
||||
// Whether the resume token provided in the request has been used or not.
|
||||
UsedResumeToken bool `protobuf:"varint,1,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"`
|
||||
UsedResumeToken bool `protobuf:"varint,2,opt,name=UsedResumeToken,proto3" json:"UsedResumeToken,omitempty"`
|
||||
// Expected stream size determined by dry run, not exact.
|
||||
// 0 indicates that for the given SendReq, no size estimate could be made.
|
||||
ExpectedSize int64 `protobuf:"varint,2,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"`
|
||||
Properties []*Property `protobuf:"bytes,3,rep,name=Properties,proto3" json:"Properties,omitempty"`
|
||||
ExpectedSize int64 `protobuf:"varint,3,opt,name=ExpectedSize,proto3" json:"ExpectedSize,omitempty"`
|
||||
Properties []*Property `protobuf:"bytes,4,rep,name=Properties,proto3" json:"Properties,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
@ -457,7 +470,7 @@ func (m *SendRes) Reset() { *m = SendRes{} }
|
||||
func (m *SendRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*SendRes) ProtoMessage() {}
|
||||
func (*SendRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{8}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{8}
|
||||
}
|
||||
func (m *SendRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_SendRes.Unmarshal(m, b)
|
||||
@ -511,7 +524,7 @@ func (m *ReceiveReq) Reset() { *m = ReceiveReq{} }
|
||||
func (m *ReceiveReq) String() string { return proto.CompactTextString(m) }
|
||||
func (*ReceiveReq) ProtoMessage() {}
|
||||
func (*ReceiveReq) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{9}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{9}
|
||||
}
|
||||
func (m *ReceiveReq) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ReceiveReq.Unmarshal(m, b)
|
||||
@ -555,7 +568,7 @@ func (m *ReceiveRes) Reset() { *m = ReceiveRes{} }
|
||||
func (m *ReceiveRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*ReceiveRes) ProtoMessage() {}
|
||||
func (*ReceiveRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{10}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{10}
|
||||
}
|
||||
func (m *ReceiveRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ReceiveRes.Unmarshal(m, b)
|
||||
@ -588,7 +601,7 @@ func (m *DestroySnapshotsReq) Reset() { *m = DestroySnapshotsReq{} }
|
||||
func (m *DestroySnapshotsReq) String() string { return proto.CompactTextString(m) }
|
||||
func (*DestroySnapshotsReq) ProtoMessage() {}
|
||||
func (*DestroySnapshotsReq) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{11}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{11}
|
||||
}
|
||||
func (m *DestroySnapshotsReq) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_DestroySnapshotsReq.Unmarshal(m, b)
|
||||
@ -634,7 +647,7 @@ func (m *DestroySnapshotRes) Reset() { *m = DestroySnapshotRes{} }
|
||||
func (m *DestroySnapshotRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*DestroySnapshotRes) ProtoMessage() {}
|
||||
func (*DestroySnapshotRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{12}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{12}
|
||||
}
|
||||
func (m *DestroySnapshotRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_DestroySnapshotRes.Unmarshal(m, b)
|
||||
@ -679,7 +692,7 @@ func (m *DestroySnapshotsRes) Reset() { *m = DestroySnapshotsRes{} }
|
||||
func (m *DestroySnapshotsRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*DestroySnapshotsRes) ProtoMessage() {}
|
||||
func (*DestroySnapshotsRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{13}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{13}
|
||||
}
|
||||
func (m *DestroySnapshotsRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_DestroySnapshotsRes.Unmarshal(m, b)
|
||||
@ -721,7 +734,7 @@ func (m *ReplicationCursorReq) Reset() { *m = ReplicationCursorReq{} }
|
||||
func (m *ReplicationCursorReq) String() string { return proto.CompactTextString(m) }
|
||||
func (*ReplicationCursorReq) ProtoMessage() {}
|
||||
func (*ReplicationCursorReq) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{14}
|
||||
}
|
||||
func (m *ReplicationCursorReq) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ReplicationCursorReq.Unmarshal(m, b)
|
||||
@ -869,7 +882,7 @@ func (m *ReplicationCursorReq_GetOp) Reset() { *m = ReplicationCursorReq
|
||||
func (m *ReplicationCursorReq_GetOp) String() string { return proto.CompactTextString(m) }
|
||||
func (*ReplicationCursorReq_GetOp) ProtoMessage() {}
|
||||
func (*ReplicationCursorReq_GetOp) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 0}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{14, 0}
|
||||
}
|
||||
func (m *ReplicationCursorReq_GetOp) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ReplicationCursorReq_GetOp.Unmarshal(m, b)
|
||||
@ -900,7 +913,7 @@ func (m *ReplicationCursorReq_SetOp) Reset() { *m = ReplicationCursorReq
|
||||
func (m *ReplicationCursorReq_SetOp) String() string { return proto.CompactTextString(m) }
|
||||
func (*ReplicationCursorReq_SetOp) ProtoMessage() {}
|
||||
func (*ReplicationCursorReq_SetOp) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{14, 1}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{14, 1}
|
||||
}
|
||||
func (m *ReplicationCursorReq_SetOp) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ReplicationCursorReq_SetOp.Unmarshal(m, b)
|
||||
@ -941,7 +954,7 @@ func (m *ReplicationCursorRes) Reset() { *m = ReplicationCursorRes{} }
|
||||
func (m *ReplicationCursorRes) String() string { return proto.CompactTextString(m) }
|
||||
func (*ReplicationCursorRes) ProtoMessage() {}
|
||||
func (*ReplicationCursorRes) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_pdu_fe566e6b212fcf8d, []int{15}
|
||||
return fileDescriptor_pdu_89315d819a6e0938, []int{15}
|
||||
}
|
||||
func (m *ReplicationCursorRes) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ReplicationCursorRes.Unmarshal(m, b)
|
||||
@ -1067,71 +1080,246 @@ func _ReplicationCursorRes_OneofSizer(msg proto.Message) (n int) {
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*ListFilesystemReq)(nil), "pdu.ListFilesystemReq")
|
||||
proto.RegisterType((*ListFilesystemRes)(nil), "pdu.ListFilesystemRes")
|
||||
proto.RegisterType((*Filesystem)(nil), "pdu.Filesystem")
|
||||
proto.RegisterType((*ListFilesystemVersionsReq)(nil), "pdu.ListFilesystemVersionsReq")
|
||||
proto.RegisterType((*ListFilesystemVersionsRes)(nil), "pdu.ListFilesystemVersionsRes")
|
||||
proto.RegisterType((*FilesystemVersion)(nil), "pdu.FilesystemVersion")
|
||||
proto.RegisterType((*SendReq)(nil), "pdu.SendReq")
|
||||
proto.RegisterType((*Property)(nil), "pdu.Property")
|
||||
proto.RegisterType((*SendRes)(nil), "pdu.SendRes")
|
||||
proto.RegisterType((*ReceiveReq)(nil), "pdu.ReceiveReq")
|
||||
proto.RegisterType((*ReceiveRes)(nil), "pdu.ReceiveRes")
|
||||
proto.RegisterType((*DestroySnapshotsReq)(nil), "pdu.DestroySnapshotsReq")
|
||||
proto.RegisterType((*DestroySnapshotRes)(nil), "pdu.DestroySnapshotRes")
|
||||
proto.RegisterType((*DestroySnapshotsRes)(nil), "pdu.DestroySnapshotsRes")
|
||||
proto.RegisterType((*ReplicationCursorReq)(nil), "pdu.ReplicationCursorReq")
|
||||
proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "pdu.ReplicationCursorReq.GetOp")
|
||||
proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "pdu.ReplicationCursorReq.SetOp")
|
||||
proto.RegisterType((*ReplicationCursorRes)(nil), "pdu.ReplicationCursorRes")
|
||||
proto.RegisterEnum("pdu.FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value)
|
||||
proto.RegisterType((*ListFilesystemReq)(nil), "ListFilesystemReq")
|
||||
proto.RegisterType((*ListFilesystemRes)(nil), "ListFilesystemRes")
|
||||
proto.RegisterType((*Filesystem)(nil), "Filesystem")
|
||||
proto.RegisterType((*ListFilesystemVersionsReq)(nil), "ListFilesystemVersionsReq")
|
||||
proto.RegisterType((*ListFilesystemVersionsRes)(nil), "ListFilesystemVersionsRes")
|
||||
proto.RegisterType((*FilesystemVersion)(nil), "FilesystemVersion")
|
||||
proto.RegisterType((*SendReq)(nil), "SendReq")
|
||||
proto.RegisterType((*Property)(nil), "Property")
|
||||
proto.RegisterType((*SendRes)(nil), "SendRes")
|
||||
proto.RegisterType((*ReceiveReq)(nil), "ReceiveReq")
|
||||
proto.RegisterType((*ReceiveRes)(nil), "ReceiveRes")
|
||||
proto.RegisterType((*DestroySnapshotsReq)(nil), "DestroySnapshotsReq")
|
||||
proto.RegisterType((*DestroySnapshotRes)(nil), "DestroySnapshotRes")
|
||||
proto.RegisterType((*DestroySnapshotsRes)(nil), "DestroySnapshotsRes")
|
||||
proto.RegisterType((*ReplicationCursorReq)(nil), "ReplicationCursorReq")
|
||||
proto.RegisterType((*ReplicationCursorReq_GetOp)(nil), "ReplicationCursorReq.GetOp")
|
||||
proto.RegisterType((*ReplicationCursorReq_SetOp)(nil), "ReplicationCursorReq.SetOp")
|
||||
proto.RegisterType((*ReplicationCursorRes)(nil), "ReplicationCursorRes")
|
||||
proto.RegisterEnum("FilesystemVersion_VersionType", FilesystemVersion_VersionType_name, FilesystemVersion_VersionType_value)
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_fe566e6b212fcf8d) }
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
var fileDescriptor_pdu_fe566e6b212fcf8d = []byte{
|
||||
// 659 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdb, 0x6e, 0x13, 0x31,
|
||||
0x10, 0xcd, 0xe6, 0xba, 0x99, 0x94, 0x5e, 0xdc, 0xaa, 0x2c, 0x15, 0x82, 0xc8, 0xbc, 0x04, 0x24,
|
||||
0x22, 0x91, 0x56, 0xbc, 0xf0, 0x96, 0xde, 0xf2, 0x80, 0xda, 0xca, 0x09, 0x55, 0x9f, 0x90, 0x42,
|
||||
0x77, 0x44, 0x57, 0xb9, 0x78, 0x6b, 0x7b, 0x51, 0xc3, 0x07, 0xf0, 0x4f, 0xfc, 0x07, 0x0f, 0x7c,
|
||||
0x0e, 0xf2, 0xec, 0x25, 0xdb, 0x24, 0x54, 0x79, 0x8a, 0xcf, 0xf8, 0x78, 0xe6, 0xcc, 0xf1, 0x8e,
|
||||
0x03, 0xf5, 0xd0, 0x8f, 0xda, 0xa1, 0x92, 0x46, 0xb2, 0x52, 0xe8, 0x47, 0x7c, 0x17, 0x76, 0x3e,
|
||||
0x07, 0xda, 0x9c, 0x05, 0x63, 0xd4, 0x33, 0x6d, 0x70, 0x22, 0xf0, 0x9e, 0x9f, 0x2d, 0x07, 0x35,
|
||||
0xfb, 0x00, 0x8d, 0x79, 0x40, 0x7b, 0x4e, 0xb3, 0xd4, 0x6a, 0x74, 0xb6, 0xda, 0x36, 0x5f, 0x8e,
|
||||
0x98, 0xe7, 0xf0, 0x2e, 0xc0, 0x1c, 0x32, 0x06, 0xe5, 0xab, 0xa1, 0xb9, 0xf3, 0x9c, 0xa6, 0xd3,
|
||||
0xaa, 0x0b, 0x5a, 0xb3, 0x26, 0x34, 0x04, 0xea, 0x68, 0x82, 0x03, 0x39, 0xc2, 0xa9, 0x57, 0xa4,
|
||||
0xad, 0x7c, 0x88, 0x7f, 0x82, 0x17, 0x8f, 0xb5, 0x5c, 0xa3, 0xd2, 0x81, 0x9c, 0x6a, 0x81, 0xf7,
|
||||
0xec, 0x55, 0xbe, 0x40, 0x92, 0x38, 0x17, 0xe1, 0x97, 0xff, 0x3f, 0xac, 0x59, 0x07, 0xdc, 0x14,
|
||||
0x26, 0xdd, 0xec, 0x2f, 0x74, 0x93, 0x6c, 0x8b, 0x8c, 0xc7, 0xff, 0x3a, 0xb0, 0xb3, 0xb4, 0xcf,
|
||||
0x3e, 0x42, 0x79, 0x30, 0x0b, 0x91, 0x04, 0x6c, 0x76, 0xf8, 0xea, 0x2c, 0xed, 0xe4, 0xd7, 0x32,
|
||||
0x05, 0xf1, 0xad, 0x23, 0x17, 0xc3, 0x09, 0x26, 0x6d, 0xd3, 0xda, 0xc6, 0xce, 0xa3, 0xc0, 0xf7,
|
||||
0x4a, 0x4d, 0xa7, 0x55, 0x16, 0xb4, 0x66, 0x2f, 0xa1, 0x7e, 0xac, 0x70, 0x68, 0x70, 0x70, 0x73,
|
||||
0xee, 0x95, 0x69, 0x63, 0x1e, 0x60, 0x07, 0xe0, 0x12, 0x08, 0xe4, 0xd4, 0xab, 0x50, 0xa6, 0x0c,
|
||||
0xf3, 0xb7, 0xd0, 0xc8, 0x95, 0x65, 0x1b, 0xe0, 0xf6, 0xa7, 0xc3, 0x50, 0xdf, 0x49, 0xb3, 0x5d,
|
||||
0xb0, 0xa8, 0x2b, 0xe5, 0x68, 0x32, 0x54, 0xa3, 0x6d, 0x87, 0xff, 0x76, 0xa0, 0xd6, 0xc7, 0xa9,
|
||||
0xbf, 0x86, 0xaf, 0x56, 0xe4, 0x99, 0x92, 0x93, 0x54, 0xb8, 0x5d, 0xb3, 0x4d, 0x28, 0x0e, 0x24,
|
||||
0xc9, 0xae, 0x8b, 0xe2, 0x40, 0x2e, 0x5e, 0x6d, 0x79, 0xe9, 0x6a, 0x49, 0xb8, 0x9c, 0x84, 0x0a,
|
||||
0xb5, 0x26, 0xe1, 0xae, 0xc8, 0x30, 0xdb, 0x83, 0xca, 0x09, 0xfa, 0x51, 0xe8, 0x55, 0x69, 0x23,
|
||||
0x06, 0x6c, 0x1f, 0xaa, 0x27, 0x6a, 0x26, 0xa2, 0xa9, 0x57, 0xa3, 0x70, 0x82, 0xf8, 0x11, 0xb8,
|
||||
0x57, 0x4a, 0x86, 0xa8, 0xcc, 0x2c, 0x33, 0xd5, 0xc9, 0x99, 0xba, 0x07, 0x95, 0xeb, 0xe1, 0x38,
|
||||
0x4a, 0x9d, 0x8e, 0x01, 0xff, 0x95, 0x75, 0xac, 0x59, 0x0b, 0xb6, 0xbe, 0x68, 0xf4, 0xf3, 0x8a,
|
||||
0x1d, 0x2a, 0xb1, 0x18, 0x66, 0x1c, 0x36, 0x4e, 0x1f, 0x42, 0xbc, 0x35, 0xe8, 0xf7, 0x83, 0x9f,
|
||||
0x71, 0xca, 0x92, 0x78, 0x14, 0x63, 0xef, 0x01, 0x12, 0x3d, 0x01, 0x6a, 0xaf, 0x44, 0x1f, 0xd7,
|
||||
0x33, 0xfa, 0x2c, 0x52, 0x99, 0x22, 0x47, 0xe0, 0x37, 0x00, 0x02, 0x6f, 0x31, 0xf8, 0x81, 0xeb,
|
||||
0x98, 0xff, 0x0e, 0xb6, 0x8f, 0xc7, 0x38, 0x54, 0x8b, 0x83, 0xe3, 0x8a, 0xa5, 0x38, 0xdf, 0xc8,
|
||||
0x65, 0xd6, 0x7c, 0x04, 0xbb, 0x27, 0xa8, 0x8d, 0x92, 0xb3, 0xf4, 0x2b, 0x58, 0x67, 0x8a, 0xd8,
|
||||
0x11, 0xd4, 0x33, 0xbe, 0x57, 0x7c, 0x72, 0x52, 0xe6, 0x44, 0xfe, 0x15, 0xd8, 0x42, 0xb1, 0x64,
|
||||
0xe8, 0x52, 0x48, 0x95, 0x9e, 0x18, 0xba, 0x94, 0x67, 0x6f, 0xef, 0x54, 0x29, 0xa9, 0xd2, 0xdb,
|
||||
0x23, 0xc0, 0x7b, 0xab, 0x9a, 0xb1, 0xcf, 0x54, 0xcd, 0x1a, 0x30, 0x36, 0xe9, 0x50, 0x3f, 0xa7,
|
||||
0xfc, 0xcb, 0x52, 0x44, 0xca, 0xe3, 0x7f, 0x1c, 0xd8, 0x13, 0x18, 0x8e, 0x83, 0x5b, 0x1a, 0x9a,
|
||||
0xe3, 0x48, 0x69, 0xa9, 0xd6, 0x31, 0xe6, 0x10, 0x4a, 0xdf, 0xd1, 0x90, 0xac, 0x46, 0xe7, 0x35,
|
||||
0xd5, 0x59, 0x95, 0xa7, 0x7d, 0x8e, 0xe6, 0x32, 0xec, 0x15, 0x84, 0x65, 0xdb, 0x43, 0x1a, 0x0d,
|
||||
0x0d, 0xca, 0x93, 0x87, 0xfa, 0xe9, 0x21, 0x8d, 0xe6, 0xa0, 0x06, 0x15, 0x4a, 0x72, 0xf0, 0x06,
|
||||
0x2a, 0xb4, 0x61, 0x87, 0x27, 0x33, 0x32, 0xf6, 0x25, 0xc3, 0xdd, 0x32, 0x14, 0x65, 0xc8, 0x07,
|
||||
0x2b, 0xbb, 0xb2, 0xa3, 0x15, 0xbf, 0x30, 0xb6, 0x9f, 0x72, 0xaf, 0x90, 0xbd, 0x31, 0xee, 0x85,
|
||||
0x34, 0xf8, 0x10, 0xe8, 0x38, 0x9f, 0xdb, 0x2b, 0x88, 0x2c, 0xd2, 0x75, 0xa1, 0x1a, 0xbb, 0xf5,
|
||||
0xad, 0x4a, 0x7f, 0x1e, 0x87, 0xff, 0x02, 0x00, 0x00, 0xff, 0xff, 0x66, 0x74, 0x36, 0x3a, 0x49,
|
||||
0x06, 0x00, 0x00,
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// ReplicationClient is the client API for Replication service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type ReplicationClient interface {
|
||||
ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error)
|
||||
ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error)
|
||||
DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error)
|
||||
ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error)
|
||||
}
|
||||
|
||||
type replicationClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewReplicationClient(cc *grpc.ClientConn) ReplicationClient {
|
||||
return &replicationClient{cc}
|
||||
}
|
||||
|
||||
func (c *replicationClient) ListFilesystems(ctx context.Context, in *ListFilesystemReq, opts ...grpc.CallOption) (*ListFilesystemRes, error) {
|
||||
out := new(ListFilesystemRes)
|
||||
err := c.cc.Invoke(ctx, "/Replication/ListFilesystems", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *replicationClient) ListFilesystemVersions(ctx context.Context, in *ListFilesystemVersionsReq, opts ...grpc.CallOption) (*ListFilesystemVersionsRes, error) {
|
||||
out := new(ListFilesystemVersionsRes)
|
||||
err := c.cc.Invoke(ctx, "/Replication/ListFilesystemVersions", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *replicationClient) DestroySnapshots(ctx context.Context, in *DestroySnapshotsReq, opts ...grpc.CallOption) (*DestroySnapshotsRes, error) {
|
||||
out := new(DestroySnapshotsRes)
|
||||
err := c.cc.Invoke(ctx, "/Replication/DestroySnapshots", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *replicationClient) ReplicationCursor(ctx context.Context, in *ReplicationCursorReq, opts ...grpc.CallOption) (*ReplicationCursorRes, error) {
|
||||
out := new(ReplicationCursorRes)
|
||||
err := c.cc.Invoke(ctx, "/Replication/ReplicationCursor", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ReplicationServer is the server API for Replication service.
|
||||
type ReplicationServer interface {
|
||||
ListFilesystems(context.Context, *ListFilesystemReq) (*ListFilesystemRes, error)
|
||||
ListFilesystemVersions(context.Context, *ListFilesystemVersionsReq) (*ListFilesystemVersionsRes, error)
|
||||
DestroySnapshots(context.Context, *DestroySnapshotsReq) (*DestroySnapshotsRes, error)
|
||||
ReplicationCursor(context.Context, *ReplicationCursorReq) (*ReplicationCursorRes, error)
|
||||
}
|
||||
|
||||
func RegisterReplicationServer(s *grpc.Server, srv ReplicationServer) {
|
||||
s.RegisterService(&_Replication_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Replication_ListFilesystems_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ListFilesystemReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ReplicationServer).ListFilesystems(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/Replication/ListFilesystems",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ReplicationServer).ListFilesystems(ctx, req.(*ListFilesystemReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Replication_ListFilesystemVersions_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ListFilesystemVersionsReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ReplicationServer).ListFilesystemVersions(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/Replication/ListFilesystemVersions",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ReplicationServer).ListFilesystemVersions(ctx, req.(*ListFilesystemVersionsReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Replication_DestroySnapshots_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(DestroySnapshotsReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ReplicationServer).DestroySnapshots(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/Replication/DestroySnapshots",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ReplicationServer).DestroySnapshots(ctx, req.(*DestroySnapshotsReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Replication_ReplicationCursor_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ReplicationCursorReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ReplicationServer).ReplicationCursor(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/Replication/ReplicationCursor",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ReplicationServer).ReplicationCursor(ctx, req.(*ReplicationCursorReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Replication_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "Replication",
|
||||
HandlerType: (*ReplicationServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "ListFilesystems",
|
||||
Handler: _Replication_ListFilesystems_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ListFilesystemVersions",
|
||||
Handler: _Replication_ListFilesystemVersions_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "DestroySnapshots",
|
||||
Handler: _Replication_DestroySnapshots_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ReplicationCursor",
|
||||
Handler: _Replication_ReplicationCursor_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "pdu.proto",
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("pdu.proto", fileDescriptor_pdu_89315d819a6e0938) }
|
||||
|
||||
var fileDescriptor_pdu_89315d819a6e0938 = []byte{
|
||||
// 735 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x55, 0xdd, 0x6e, 0xda, 0x4a,
|
||||
0x10, 0xc6, 0x60, 0xc0, 0x0c, 0x51, 0x42, 0x36, 0x9c, 0xc8, 0xc7, 0xe7, 0x28, 0x42, 0xdb, 0x1b,
|
||||
0x52, 0xa9, 0x6e, 0x45, 0x7b, 0x53, 0x55, 0xaa, 0x54, 0x42, 0x7e, 0xa4, 0x56, 0x69, 0xb4, 0xd0,
|
||||
0x28, 0xca, 0x1d, 0x0d, 0xa3, 0xc4, 0x0a, 0xb0, 0xce, 0xee, 0xba, 0x0a, 0xbd, 0xec, 0x7b, 0xf4,
|
||||
0x41, 0xfa, 0x0e, 0xbd, 0xec, 0x03, 0x55, 0xbb, 0x60, 0xe3, 0x60, 0x23, 0x71, 0xe5, 0xfd, 0xbe,
|
||||
0x9d, 0x9d, 0x9d, 0xf9, 0x76, 0x66, 0x0c, 0xb5, 0x70, 0x14, 0xf9, 0xa1, 0xe0, 0x8a, 0xd3, 0x3d,
|
||||
0xd8, 0xfd, 0x14, 0x48, 0x75, 0x12, 0x8c, 0x51, 0xce, 0xa4, 0xc2, 0x09, 0xc3, 0x07, 0x7a, 0x95,
|
||||
0x25, 0x25, 0x79, 0x01, 0xf5, 0x25, 0x21, 0x5d, 0xab, 0x55, 0x6a, 0xd7, 0x3b, 0x75, 0x3f, 0x65,
|
||||
0x94, 0xde, 0x27, 0x4d, 0x28, 0x1f, 0x4f, 0x42, 0x35, 0x73, 0x8b, 0x2d, 0xab, 0xed, 0xb0, 0x39,
|
||||
0xa0, 0x5d, 0x80, 0xa5, 0x11, 0x21, 0x60, 0x5f, 0x0c, 0xd5, 0x9d, 0x6b, 0xb5, 0xac, 0x76, 0x8d,
|
||||
0x99, 0x35, 0x69, 0x41, 0x9d, 0xa1, 0x8c, 0x26, 0x38, 0xe0, 0xf7, 0x38, 0x35, 0xa7, 0x6b, 0x2c,
|
||||
0x4d, 0xd1, 0x77, 0xf0, 0xef, 0xd3, 0xe8, 0x2e, 0x51, 0xc8, 0x80, 0x4f, 0x25, 0xc3, 0x07, 0x72,
|
||||
0x90, 0xbe, 0x60, 0xe1, 0x38, 0xc5, 0xd0, 0x8f, 0xeb, 0x0f, 0x4b, 0xe2, 0x83, 0x13, 0xc3, 0x45,
|
||||
0x7e, 0xc4, 0xcf, 0x58, 0xb2, 0xc4, 0x86, 0xfe, 0xb1, 0x60, 0x37, 0xb3, 0x4f, 0x3a, 0x60, 0x0f,
|
||||
0x66, 0x21, 0x9a, 0xcb, 0xb7, 0x3b, 0x07, 0x59, 0x0f, 0xfe, 0xe2, 0xab, 0xad, 0x98, 0xb1, 0xd5,
|
||||
0x4a, 0x9c, 0x0f, 0x27, 0xb8, 0x48, 0xd7, 0xac, 0x35, 0x77, 0x1a, 0x05, 0x23, 0xb7, 0xd4, 0xb2,
|
||||
0xda, 0x36, 0x33, 0x6b, 0xf2, 0x3f, 0xd4, 0x8e, 0x04, 0x0e, 0x15, 0x0e, 0xae, 0x4e, 0x5d, 0xdb,
|
||||
0x6c, 0x2c, 0x09, 0xe2, 0x81, 0x63, 0x40, 0xc0, 0xa7, 0x6e, 0xd9, 0x78, 0x4a, 0x30, 0x3d, 0x84,
|
||||
0x7a, 0xea, 0x5a, 0xb2, 0x05, 0x4e, 0x7f, 0x3a, 0x0c, 0xe5, 0x1d, 0x57, 0x8d, 0x82, 0x46, 0x5d,
|
||||
0xce, 0xef, 0x27, 0x43, 0x71, 0xdf, 0xb0, 0xe8, 0x2f, 0x0b, 0xaa, 0x7d, 0x9c, 0x8e, 0x36, 0xd0,
|
||||
0x53, 0x07, 0x79, 0x22, 0xf8, 0x24, 0x0e, 0x5c, 0xaf, 0xc9, 0x36, 0x14, 0x07, 0xdc, 0x84, 0x5d,
|
||||
0x63, 0xc5, 0x01, 0x5f, 0x7d, 0x52, 0x3b, 0xf3, 0xa4, 0x26, 0x70, 0x3e, 0x09, 0x05, 0x4a, 0x69,
|
||||
0x02, 0x77, 0x58, 0x82, 0x75, 0x21, 0xf5, 0x70, 0x14, 0x85, 0x6e, 0x65, 0x5e, 0x48, 0x06, 0x90,
|
||||
0x7d, 0xa8, 0xf4, 0xc4, 0x8c, 0x45, 0x53, 0xb7, 0x6a, 0xe8, 0x05, 0xa2, 0x6f, 0xc0, 0xb9, 0x10,
|
||||
0x3c, 0x44, 0xa1, 0x66, 0x89, 0xa8, 0x56, 0x4a, 0xd4, 0x26, 0x94, 0x2f, 0x87, 0xe3, 0x28, 0x56,
|
||||
0x7a, 0x0e, 0xe8, 0x8f, 0x24, 0x63, 0x49, 0xda, 0xb0, 0xf3, 0x45, 0xe2, 0x68, 0xb5, 0x08, 0x1d,
|
||||
0xb6, 0x4a, 0x13, 0x0a, 0x5b, 0xc7, 0x8f, 0x21, 0xde, 0x28, 0x1c, 0xf5, 0x83, 0xef, 0x68, 0x32,
|
||||
0x2e, 0xb1, 0x27, 0x1c, 0x39, 0x04, 0x58, 0xc4, 0x13, 0xa0, 0x74, 0x6d, 0x53, 0x54, 0x35, 0x3f,
|
||||
0x0e, 0x91, 0xa5, 0x36, 0xe9, 0x15, 0x00, 0xc3, 0x1b, 0x0c, 0xbe, 0xe1, 0x26, 0xc2, 0x3f, 0x87,
|
||||
0xc6, 0xd1, 0x18, 0x87, 0x22, 0x1b, 0x67, 0x86, 0xa7, 0x5b, 0x29, 0xcf, 0x92, 0xde, 0xc2, 0x5e,
|
||||
0x0f, 0xa5, 0x12, 0x7c, 0x16, 0x57, 0xc0, 0x26, 0x9d, 0x43, 0x5e, 0x41, 0x2d, 0xb1, 0x77, 0x8b,
|
||||
0x6b, 0xbb, 0x63, 0x69, 0x44, 0xaf, 0x81, 0xac, 0x5c, 0xb4, 0x68, 0xb2, 0x18, 0x9a, 0x5b, 0xd6,
|
||||
0x34, 0x59, 0x6c, 0x63, 0x06, 0x89, 0x10, 0x5c, 0xc4, 0x2f, 0x66, 0x00, 0xed, 0xe5, 0x25, 0xa1,
|
||||
0x87, 0x54, 0x55, 0x27, 0x3e, 0x56, 0x71, 0x03, 0xef, 0xf9, 0xd9, 0x10, 0x58, 0x6c, 0x43, 0x7f,
|
||||
0x5b, 0xd0, 0x64, 0x18, 0x8e, 0x83, 0x1b, 0xd3, 0x24, 0x47, 0x91, 0x90, 0x5c, 0x6c, 0x22, 0xc6,
|
||||
0x4b, 0x28, 0xdd, 0xa2, 0x32, 0x21, 0xd5, 0x3b, 0xff, 0xf9, 0x79, 0x3e, 0xfc, 0x53, 0x54, 0x9f,
|
||||
0xc3, 0xb3, 0x02, 0xd3, 0x96, 0xfa, 0x80, 0x44, 0x65, 0x4a, 0x64, 0xed, 0x81, 0x7e, 0x7c, 0x40,
|
||||
0xa2, 0xf2, 0xaa, 0x50, 0x36, 0x0e, 0xbc, 0x67, 0x50, 0x36, 0x1b, 0xba, 0x49, 0x12, 0xe1, 0xe6,
|
||||
0x5a, 0x24, 0xb8, 0x6b, 0x43, 0x91, 0x87, 0x74, 0x90, 0x9b, 0x8d, 0x6e, 0xa1, 0xf9, 0x24, 0xd1,
|
||||
0x79, 0xd8, 0x67, 0x85, 0x64, 0x96, 0x38, 0xe7, 0x5c, 0xe1, 0x63, 0x20, 0xe7, 0xfe, 0x9c, 0xb3,
|
||||
0x02, 0x4b, 0x98, 0xae, 0x03, 0x95, 0xb9, 0x4a, 0x9d, 0x9f, 0x45, 0xdd, 0xbf, 0x89, 0x5b, 0xf2,
|
||||
0x16, 0x76, 0x9e, 0x8e, 0x50, 0x49, 0x88, 0x9f, 0xf9, 0x89, 0x78, 0x59, 0x4e, 0x92, 0x0b, 0xd8,
|
||||
0xcf, 0x9f, 0xbe, 0xc4, 0xf3, 0xd7, 0xce, 0x74, 0x6f, 0xfd, 0x9e, 0x24, 0xef, 0xa1, 0xb1, 0x5a,
|
||||
0x07, 0xa4, 0xe9, 0xe7, 0xd4, 0xb7, 0x97, 0xc7, 0x4a, 0xf2, 0x01, 0x76, 0x33, 0x92, 0x91, 0x7f,
|
||||
0x72, 0xdf, 0xc7, 0xcb, 0xa5, 0x65, 0xb7, 0x7c, 0x5d, 0x0a, 0x47, 0xd1, 0xd7, 0x8a, 0xf9, 0xa1,
|
||||
0xbe, 0xfe, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xa3, 0xba, 0x8e, 0x63, 0x5d, 0x07, 0x00, 0x00,
|
||||
}
|
||||
|
@ -1,11 +1,19 @@
|
||||
syntax = "proto3";
|
||||
option go_package = "pdu";
|
||||
|
||||
package pdu;
|
||||
service Replication {
|
||||
rpc ListFilesystems (ListFilesystemReq) returns (ListFilesystemRes);
|
||||
rpc ListFilesystemVersions (ListFilesystemVersionsReq) returns (ListFilesystemVersionsRes);
|
||||
rpc DestroySnapshots (DestroySnapshotsReq) returns (DestroySnapshotsRes);
|
||||
rpc ReplicationCursor (ReplicationCursorReq) returns (ReplicationCursorRes);
|
||||
// for Send and Recv, see package rpc
|
||||
}
|
||||
|
||||
message ListFilesystemReq {}
|
||||
|
||||
message ListFilesystemRes {
|
||||
repeated Filesystem Filesystems = 1;
|
||||
bool Empty = 2;
|
||||
}
|
||||
|
||||
message Filesystem {
|
||||
@ -60,22 +68,18 @@ message Property {
|
||||
}
|
||||
|
||||
message SendRes {
|
||||
// The actual stream is in the stream part of the streamrpc response
|
||||
|
||||
// Whether the resume token provided in the request has been used or not.
|
||||
bool UsedResumeToken = 1;
|
||||
bool UsedResumeToken = 2;
|
||||
|
||||
// Expected stream size determined by dry run, not exact.
|
||||
// 0 indicates that for the given SendReq, no size estimate could be made.
|
||||
int64 ExpectedSize = 2;
|
||||
int64 ExpectedSize = 3;
|
||||
|
||||
repeated Property Properties = 3;
|
||||
repeated Property Properties = 4;
|
||||
}
|
||||
|
||||
message ReceiveReq {
|
||||
// The stream part of the streamrpc request contains the zfs send stream
|
||||
|
||||
string Filesystem = 1;
|
||||
string Filesystem = 1; // FIXME should be snapshot name, we can enforce that on recv
|
||||
|
||||
// If true, the receiver should clear the resume token before perfoming the zfs recv of the stream in the request
|
||||
bool ClearResumeToken = 2;
|
||||
|
169
rpc/dataconn/base2bufpool/base2bufpool.go
Normal file
169
rpc/dataconn/base2bufpool/base2bufpool.go
Normal file
@ -0,0 +1,169 @@
|
||||
package base2bufpool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/bits"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type pool struct {
|
||||
mtx sync.Mutex
|
||||
bufs [][]byte
|
||||
shift uint
|
||||
}
|
||||
|
||||
func (p *pool) Put(buf []byte) {
|
||||
p.mtx.Lock()
|
||||
defer p.mtx.Unlock()
|
||||
if len(buf) != 1<<p.shift {
|
||||
panic(fmt.Sprintf("implementation error: %v %v", len(buf), 1<<p.shift))
|
||||
}
|
||||
if len(p.bufs) > 10 { // FIXME constant
|
||||
return
|
||||
}
|
||||
p.bufs = append(p.bufs, buf)
|
||||
}
|
||||
|
||||
func (p *pool) Get() []byte {
|
||||
p.mtx.Lock()
|
||||
defer p.mtx.Unlock()
|
||||
if len(p.bufs) > 0 {
|
||||
ret := p.bufs[len(p.bufs)-1]
|
||||
p.bufs = p.bufs[0 : len(p.bufs)-1]
|
||||
return ret
|
||||
}
|
||||
return make([]byte, 1<<p.shift)
|
||||
}
|
||||
|
||||
type Pool struct {
|
||||
minShift uint
|
||||
maxShift uint
|
||||
pools []pool
|
||||
onNoFit NoFitBehavior
|
||||
}
|
||||
|
||||
type Buffer struct {
|
||||
// always power of 2, from Pool.pools
|
||||
shiftBuf []byte
|
||||
// presentedLen
|
||||
payloadLen uint
|
||||
// backref too pool for Free
|
||||
pool *Pool
|
||||
}
|
||||
|
||||
func (b Buffer) Bytes() []byte {
|
||||
return b.shiftBuf[0:b.payloadLen]
|
||||
}
|
||||
|
||||
func (b *Buffer) Shrink(newPayloadLen uint) {
|
||||
if newPayloadLen > b.payloadLen {
|
||||
panic(fmt.Sprintf("shrink is actually an expand, invalid: %v %v", newPayloadLen, b.payloadLen))
|
||||
}
|
||||
b.payloadLen = newPayloadLen
|
||||
}
|
||||
|
||||
func (b Buffer) Free() {
|
||||
if b.pool != nil {
|
||||
b.pool.put(b)
|
||||
}
|
||||
}
|
||||
|
||||
//go:generate enumer -type NoFitBehavior
|
||||
type NoFitBehavior uint
|
||||
|
||||
const (
|
||||
AllocateSmaller NoFitBehavior = 1 << iota
|
||||
AllocateLarger
|
||||
|
||||
Allocate NoFitBehavior = AllocateSmaller | AllocateLarger
|
||||
Panic NoFitBehavior = 0
|
||||
)
|
||||
|
||||
func New(minShift, maxShift uint, noFitBehavior NoFitBehavior) *Pool {
|
||||
if minShift > 63 || maxShift > 63 {
|
||||
panic(fmt.Sprintf("{min|max}Shift are the _exponent_, got minShift=%v maxShift=%v and limit of 63, which amounts to %v bits", minShift, maxShift, uint64(1)<<63))
|
||||
}
|
||||
pools := make([]pool, maxShift-minShift+1)
|
||||
for i := uint(0); i < uint(len(pools)); i++ {
|
||||
i := i // the closure below must copy i
|
||||
pools[i] = pool{
|
||||
shift: minShift + i,
|
||||
bufs: make([][]byte, 0, 10),
|
||||
}
|
||||
}
|
||||
return &Pool{
|
||||
minShift: minShift,
|
||||
maxShift: maxShift,
|
||||
pools: pools,
|
||||
onNoFit: noFitBehavior,
|
||||
}
|
||||
}
|
||||
|
||||
func fittingShift(x uint) uint {
|
||||
if x == 0 {
|
||||
return 0
|
||||
}
|
||||
blen := uint(bits.Len(x))
|
||||
if 1<<(blen-1) == x {
|
||||
return blen - 1
|
||||
}
|
||||
return blen
|
||||
}
|
||||
|
||||
func (p *Pool) handlePotentialNoFit(reqShift uint) (buf Buffer, didHandle bool) {
|
||||
if reqShift == 0 {
|
||||
if p.onNoFit&AllocateSmaller != 0 {
|
||||
return Buffer{[]byte{}, 0, nil}, true
|
||||
} else {
|
||||
goto doPanic
|
||||
}
|
||||
}
|
||||
if reqShift < p.minShift {
|
||||
if p.onNoFit&AllocateSmaller != 0 {
|
||||
goto alloc
|
||||
} else {
|
||||
goto doPanic
|
||||
}
|
||||
}
|
||||
if reqShift > p.maxShift {
|
||||
if p.onNoFit&AllocateLarger != 0 {
|
||||
goto alloc
|
||||
} else {
|
||||
goto doPanic
|
||||
}
|
||||
}
|
||||
return Buffer{}, false
|
||||
alloc:
|
||||
return Buffer{make([]byte, 1<<reqShift), 1 << reqShift, nil}, true
|
||||
doPanic:
|
||||
panic(fmt.Sprintf("base2bufpool: configured to panic on shift=%v (minShift=%v maxShift=%v)", reqShift, p.minShift, p.maxShift))
|
||||
}
|
||||
|
||||
func (p *Pool) Get(minSize uint) Buffer {
|
||||
shift := fittingShift(minSize)
|
||||
buf, didHandle := p.handlePotentialNoFit(shift)
|
||||
if didHandle {
|
||||
buf.Shrink(minSize)
|
||||
return buf
|
||||
}
|
||||
idx := int64(shift) - int64(p.minShift)
|
||||
return Buffer{p.pools[idx].Get(), minSize, p}
|
||||
}
|
||||
|
||||
func (p *Pool) put(buffer Buffer) {
|
||||
if buffer.pool != p {
|
||||
panic("putting buffer to pool where it didn't originate from")
|
||||
}
|
||||
buf := buffer.shiftBuf
|
||||
if bits.OnesCount(uint(len(buf))) > 1 {
|
||||
panic(fmt.Sprintf("putting buffer that is not power of two len: %v", len(buf)))
|
||||
}
|
||||
if len(buf) == 0 {
|
||||
return
|
||||
}
|
||||
shift := fittingShift(uint(len(buf)))
|
||||
if shift < p.minShift || shift > p.maxShift {
|
||||
return // drop it
|
||||
}
|
||||
p.pools[shift-p.minShift].Put(buf)
|
||||
}
|
98
rpc/dataconn/base2bufpool/base2bufpool_test.go
Normal file
98
rpc/dataconn/base2bufpool/base2bufpool_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package base2bufpool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPoolAllocBehavior(t *testing.T) {
|
||||
|
||||
type testcase struct {
|
||||
poolMinShift, poolMaxShift uint
|
||||
behavior NoFitBehavior
|
||||
get uint
|
||||
expShiftBufLen int64 // -1 if panic expected
|
||||
}
|
||||
|
||||
tcs := []testcase{
|
||||
{
|
||||
15, 20, Allocate,
|
||||
1 << 14, 1 << 14,
|
||||
},
|
||||
{
|
||||
15, 20, Allocate,
|
||||
1 << 22, 1 << 22,
|
||||
},
|
||||
{
|
||||
15, 20, Panic,
|
||||
1 << 16, 1 << 16,
|
||||
},
|
||||
{
|
||||
15, 20, Panic,
|
||||
1 << 14, -1,
|
||||
},
|
||||
{
|
||||
15, 20, Panic,
|
||||
1 << 22, -1,
|
||||
},
|
||||
{
|
||||
15, 20, Panic,
|
||||
(1 << 15) + 23, 1 << 16,
|
||||
},
|
||||
{
|
||||
15, 20, Panic,
|
||||
0, -1, // yep, 0 always works, even
|
||||
},
|
||||
{
|
||||
15, 20, Allocate,
|
||||
0, 0,
|
||||
},
|
||||
{
|
||||
15, 20, AllocateSmaller,
|
||||
1 << 14, 1 << 14,
|
||||
},
|
||||
{
|
||||
15, 20, AllocateSmaller,
|
||||
1 << 22, -1,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tcs {
|
||||
tc := tcs[i]
|
||||
t.Run(fmt.Sprintf("[%d,%d] behav=%s Get(%d) exp=%d", tc.poolMinShift, tc.poolMaxShift, tc.behavior, tc.get, tc.expShiftBufLen), func(t *testing.T) {
|
||||
pool := New(tc.poolMinShift, tc.poolMaxShift, tc.behavior)
|
||||
if tc.expShiftBufLen == -1 {
|
||||
assert.Panics(t, func() {
|
||||
pool.Get(tc.get)
|
||||
})
|
||||
return
|
||||
}
|
||||
buf := pool.Get(tc.get)
|
||||
assert.True(t, uint(len(buf.Bytes())) == tc.get)
|
||||
assert.True(t, int64(len(buf.shiftBuf)) == tc.expShiftBufLen)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFittingShift(t *testing.T) {
|
||||
assert.Equal(t, uint(16), fittingShift(1+1<<15))
|
||||
assert.Equal(t, uint(15), fittingShift(1<<15))
|
||||
}
|
||||
|
||||
func TestFreeFromPoolRangeDoesNotPanic(t *testing.T) {
|
||||
pool := New(15, 20, Allocate)
|
||||
buf := pool.Get(1 << 16)
|
||||
assert.NotPanics(t, func() {
|
||||
buf.Free()
|
||||
})
|
||||
}
|
||||
|
||||
func TestFreeFromOutOfPoolRangeDoesNotPanic(t *testing.T) {
|
||||
pool := New(15, 20, Allocate)
|
||||
buf := pool.Get(1 << 23)
|
||||
assert.NotPanics(t, func() {
|
||||
buf.Free()
|
||||
})
|
||||
}
|
51
rpc/dataconn/base2bufpool/nofitbehavior_enumer.go
Normal file
51
rpc/dataconn/base2bufpool/nofitbehavior_enumer.go
Normal file
@ -0,0 +1,51 @@
|
||||
// Code generated by "enumer -type NoFitBehavior"; DO NOT EDIT.
|
||||
|
||||
package base2bufpool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const _NoFitBehaviorName = "PanicAllocateSmallerAllocateLargerAllocate"
|
||||
|
||||
var _NoFitBehaviorIndex = [...]uint8{0, 5, 20, 34, 42}
|
||||
|
||||
func (i NoFitBehavior) String() string {
|
||||
if i >= NoFitBehavior(len(_NoFitBehaviorIndex)-1) {
|
||||
return fmt.Sprintf("NoFitBehavior(%d)", i)
|
||||
}
|
||||
return _NoFitBehaviorName[_NoFitBehaviorIndex[i]:_NoFitBehaviorIndex[i+1]]
|
||||
}
|
||||
|
||||
var _NoFitBehaviorValues = []NoFitBehavior{0, 1, 2, 3}
|
||||
|
||||
var _NoFitBehaviorNameToValueMap = map[string]NoFitBehavior{
|
||||
_NoFitBehaviorName[0:5]: 0,
|
||||
_NoFitBehaviorName[5:20]: 1,
|
||||
_NoFitBehaviorName[20:34]: 2,
|
||||
_NoFitBehaviorName[34:42]: 3,
|
||||
}
|
||||
|
||||
// NoFitBehaviorString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func NoFitBehaviorString(s string) (NoFitBehavior, error) {
|
||||
if val, ok := _NoFitBehaviorNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to NoFitBehavior values", s)
|
||||
}
|
||||
|
||||
// NoFitBehaviorValues returns all values of the enum
|
||||
func NoFitBehaviorValues() []NoFitBehavior {
|
||||
return _NoFitBehaviorValues
|
||||
}
|
||||
|
||||
// IsANoFitBehavior returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i NoFitBehavior) IsANoFitBehavior() bool {
|
||||
for _, v := range _NoFitBehaviorValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
215
rpc/dataconn/dataconn_client.go
Normal file
215
rpc/dataconn/dataconn_client.go
Normal file
@ -0,0 +1,215 @@
|
||||
package dataconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/stream"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
log Logger
|
||||
cn transport.Connecter
|
||||
}
|
||||
|
||||
func NewClient(connecter transport.Connecter, log Logger) *Client {
|
||||
return &Client{
|
||||
log: log,
|
||||
cn: connecter,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error {
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, memErr := buf.WriteString(endpoint)
|
||||
if memErr != nil {
|
||||
panic(memErr)
|
||||
}
|
||||
if err := conn.WriteStreamedMessage(ctx, &buf, ReqHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protobufBytes, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
protobuf := bytes.NewBuffer(protobufBytes)
|
||||
if err := conn.WriteStreamedMessage(ctx, protobuf, ReqStructured); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if streamCopier != nil {
|
||||
return conn.SendStream(ctx, streamCopier, ZFSStream)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type RemoteHandlerError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *RemoteHandlerError) Error() string {
|
||||
return fmt.Sprintf("server error: %s", e.msg)
|
||||
}
|
||||
|
||||
type ProtocolError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e *ProtocolError) Error() string {
|
||||
return fmt.Sprintf("protocol error: %s", e)
|
||||
}
|
||||
|
||||
func (c *Client) recv(ctx context.Context, conn *stream.Conn, res proto.Message) error {
|
||||
|
||||
headerBuf, err := conn.ReadStreamedMessage(ctx, ResponseHeaderMaxSize, ResHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header := string(headerBuf)
|
||||
if strings.HasPrefix(header, responseHeaderHandlerErrorPrefix) {
|
||||
// FIXME distinguishable error type
|
||||
return &RemoteHandlerError{strings.TrimPrefix(header, responseHeaderHandlerErrorPrefix)}
|
||||
}
|
||||
if !strings.HasPrefix(header, responseHeaderHandlerOk) {
|
||||
return &ProtocolError{fmt.Errorf("invalid header: %q", header)}
|
||||
}
|
||||
|
||||
protobuf, err := conn.ReadStreamedMessage(ctx, ResponseStructuredMaxSize, ResStructured)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := proto.Unmarshal(protobuf, res); err != nil {
|
||||
return &ProtocolError{fmt.Errorf("cannot unmarshal structured part of response: %s", err)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getWire(ctx context.Context) (*stream.Conn, error) {
|
||||
nc, err := c.cn.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *Client) putWire(conn *stream.Conn) {
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.WithError(err).Error("error closing connection")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
conn, err := c.getWire(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
putWireOnReturn := true
|
||||
defer func() {
|
||||
if putWireOnReturn {
|
||||
c.putWire(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.send(ctx, conn, EndpointSend, req, nil); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var res pdu.SendRes
|
||||
if err := c.recv(ctx, conn, &res); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var copier zfs.StreamCopier = nil
|
||||
if !req.DryRun {
|
||||
putWireOnReturn = false
|
||||
copier = &streamCopier{streamConn: conn, closeStreamOnClose: true}
|
||||
}
|
||||
|
||||
return &res, copier, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
|
||||
defer c.log.Info("ReqRecv returns")
|
||||
conn, err := c.getWire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// send and recv response concurrently to catch early exists of remote handler
|
||||
// (e.g. disk full, permission error, etc)
|
||||
|
||||
type recvRes struct {
|
||||
res *pdu.ReceiveRes
|
||||
err error
|
||||
}
|
||||
recvErrChan := make(chan recvRes)
|
||||
go func() {
|
||||
res := &pdu.ReceiveRes{}
|
||||
if err := c.recv(ctx, conn, res); err != nil {
|
||||
recvErrChan <- recvRes{res, err}
|
||||
} else {
|
||||
recvErrChan <- recvRes{res, nil}
|
||||
}
|
||||
}()
|
||||
|
||||
sendErrChan := make(chan error)
|
||||
go func() {
|
||||
if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil {
|
||||
sendErrChan <- err
|
||||
} else {
|
||||
sendErrChan <- nil
|
||||
}
|
||||
}()
|
||||
|
||||
var res recvRes
|
||||
var sendErr error
|
||||
var cause error // one of the above
|
||||
didTryClose := false
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case res = <-recvErrChan:
|
||||
c.log.WithField("errType", fmt.Sprintf("%T", res.err)).WithError(res.err).Debug("recv goroutine returned")
|
||||
if res.err != nil && cause == nil {
|
||||
cause = res.err
|
||||
}
|
||||
case sendErr = <-sendErrChan:
|
||||
c.log.WithField("errType", fmt.Sprintf("%T", sendErr)).WithError(sendErr).Debug("send goroutine returned")
|
||||
if sendErr != nil && cause == nil {
|
||||
cause = sendErr
|
||||
}
|
||||
}
|
||||
if !didTryClose && (res.err != nil || sendErr != nil) {
|
||||
didTryClose = true
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.WithError(err).Error("ReqRecv: cannot close connection, will likely block indefinitely")
|
||||
}
|
||||
c.log.WithError(err).Debug("ReqRecv: closed connection, should trigger other goroutine error")
|
||||
}
|
||||
}
|
||||
|
||||
if !didTryClose {
|
||||
// didn't close it in above loop, so we can give it back
|
||||
c.putWire(conn)
|
||||
}
|
||||
|
||||
// if receive failed with a RemoteHandlerError, we know the transport was not broken
|
||||
// => take the remote error as cause for the operation to fail
|
||||
// TODO combine errors if send also failed
|
||||
// (after all, send could have crashed on our side, rendering res.err a mere symptom of the cause)
|
||||
if _, ok := res.err.(*RemoteHandlerError); ok {
|
||||
cause = res.err
|
||||
}
|
||||
|
||||
return res.res, cause
|
||||
}
|
20
rpc/dataconn/dataconn_debug.go
Normal file
20
rpc/dataconn/dataconn_debug.go
Normal file
@ -0,0 +1,20 @@
|
||||
package dataconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var debugEnabled bool = false
|
||||
|
||||
func init() {
|
||||
if os.Getenv("ZREPL_RPC_DATACONN_DEBUG") != "" {
|
||||
debugEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func debug(format string, args ...interface{}) {
|
||||
if debugEnabled {
|
||||
fmt.Fprintf(os.Stderr, "rpc/dataconn: %s\n", fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
178
rpc/dataconn/dataconn_server.go
Normal file
178
rpc/dataconn/dataconn_server.go
Normal file
@ -0,0 +1,178 @@
|
||||
package dataconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/stream"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
// WireInterceptor has a chance to exchange the context and connection on each client connection.
|
||||
type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (context.Context, *transport.AuthConn)
|
||||
|
||||
// Handler implements the functionality that is exposed by Server to the Client.
|
||||
type Handler interface {
|
||||
// Send handles a SendRequest.
|
||||
// The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run.
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
|
||||
// Receive handles a ReceiveRequest.
|
||||
// It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout
|
||||
// configured in ServerConfig.Shared.IdleConnTimeout.
|
||||
Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
|
||||
}
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
type Server struct {
|
||||
h Handler
|
||||
wi WireInterceptor
|
||||
log Logger
|
||||
}
|
||||
|
||||
func NewServer(wi WireInterceptor, logger Logger, handler Handler) *Server {
|
||||
return &Server{
|
||||
h: handler,
|
||||
wi: wi,
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Serve consumes the listener, closes it as soon as ctx is closed.
|
||||
// No accept errors are returned: they are logged to the Logger passed
|
||||
// to the constructor.
|
||||
func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) {
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.log.Debug("context done")
|
||||
if err := l.Close(); err != nil {
|
||||
s.log.WithError(err).Error("cannot close listener")
|
||||
}
|
||||
}()
|
||||
conns := make(chan *transport.AuthConn)
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept(ctx)
|
||||
if err != nil {
|
||||
if ctx.Done() != nil {
|
||||
s.log.Debug("stop accepting after context is done")
|
||||
return
|
||||
}
|
||||
s.log.WithError(err).Error("accept error")
|
||||
continue
|
||||
}
|
||||
conns <- conn
|
||||
}
|
||||
}()
|
||||
for conn := range conns {
|
||||
go s.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) serveConn(nc *transport.AuthConn) {
|
||||
s.log.Debug("serveConn begin")
|
||||
defer s.log.Debug("serveConn done")
|
||||
|
||||
ctx := context.Background()
|
||||
if s.wi != nil {
|
||||
ctx, nc = s.wi(ctx, nc)
|
||||
}
|
||||
|
||||
c := stream.Wrap(nc, HeartbeatInterval, HeartbeatPeerTimeout)
|
||||
defer func() {
|
||||
s.log.Debug("close client connection")
|
||||
if err := c.Close(); err != nil {
|
||||
s.log.WithError(err).Error("cannot close client connection")
|
||||
}
|
||||
}()
|
||||
|
||||
header, err := c.ReadStreamedMessage(ctx, RequestHeaderMaxSize, ReqHeader)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("error reading structured part")
|
||||
return
|
||||
}
|
||||
endpoint := string(header)
|
||||
|
||||
reqStructured, err := c.ReadStreamedMessage(ctx, RequestStructuredMaxSize, ReqStructured)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("error reading structured part")
|
||||
return
|
||||
}
|
||||
|
||||
s.log.WithField("endpoint", endpoint).Debug("calling handler")
|
||||
|
||||
var res proto.Message
|
||||
var sendStream zfs.StreamCopier
|
||||
var handlerErr error
|
||||
switch endpoint {
|
||||
case EndpointSend:
|
||||
var req pdu.SendReq
|
||||
if err := proto.Unmarshal(reqStructured, &req); err != nil {
|
||||
s.log.WithError(err).Error("cannot unmarshal send request")
|
||||
return
|
||||
}
|
||||
res, sendStream, handlerErr = s.h.Send(ctx, &req) // SHADOWING
|
||||
case EndpointRecv:
|
||||
var req pdu.ReceiveReq
|
||||
if err := proto.Unmarshal(reqStructured, &req); err != nil {
|
||||
s.log.WithError(err).Error("cannot unmarshal receive request")
|
||||
return
|
||||
}
|
||||
res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING
|
||||
default:
|
||||
s.log.WithField("endpoint", endpoint).Error("unknown endpoint")
|
||||
handlerErr = fmt.Errorf("requested endpoint does not exist")
|
||||
return
|
||||
}
|
||||
|
||||
s.log.WithField("endpoint", endpoint).WithField("errType", fmt.Sprintf("%T", handlerErr)).Debug("handler returned")
|
||||
|
||||
// prepare protobuf now to return the protobuf error in the header
|
||||
// if marshaling fails. We consider failed marshaling a handler error
|
||||
var protobuf *bytes.Buffer
|
||||
if handlerErr == nil {
|
||||
protobufBytes, err := proto.Marshal(res)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("cannot marshal handler protobuf")
|
||||
handlerErr = err
|
||||
}
|
||||
protobuf = bytes.NewBuffer(protobufBytes) // SHADOWING
|
||||
}
|
||||
|
||||
var resHeaderBuf bytes.Buffer
|
||||
if handlerErr == nil {
|
||||
resHeaderBuf.WriteString(responseHeaderHandlerOk)
|
||||
} else {
|
||||
resHeaderBuf.WriteString(responseHeaderHandlerErrorPrefix)
|
||||
resHeaderBuf.WriteString(handlerErr.Error())
|
||||
}
|
||||
if err := c.WriteStreamedMessage(ctx, &resHeaderBuf, ResHeader); err != nil {
|
||||
s.log.WithError(err).Error("cannot write response header")
|
||||
return
|
||||
}
|
||||
|
||||
if handlerErr != nil {
|
||||
s.log.Debug("early exit after handler error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.WriteStreamedMessage(ctx, protobuf, ResStructured); err != nil {
|
||||
s.log.WithError(err).Error("cannot write structured part of response")
|
||||
return
|
||||
}
|
||||
|
||||
if sendStream != nil {
|
||||
err := c.SendStream(ctx, sendStream, ZFSStream)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("cannot write send stream")
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
70
rpc/dataconn/dataconn_shared.go
Normal file
70
rpc/dataconn/dataconn_shared.go
Normal file
@ -0,0 +1,70 @@
|
||||
package dataconn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/stream"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
const (
|
||||
EndpointSend string = "/v1/send"
|
||||
EndpointRecv string = "/v1/recv"
|
||||
)
|
||||
|
||||
const (
|
||||
ReqHeader uint32 = 1 + iota
|
||||
ReqStructured
|
||||
ResHeader
|
||||
ResStructured
|
||||
ZFSStream
|
||||
)
|
||||
|
||||
// Note that changing theses constants may break interop with other clients
|
||||
// Aggressive with timing, conservative (future compatible) with buffer sizes
|
||||
const (
|
||||
HeartbeatInterval = 5 * time.Second
|
||||
HeartbeatPeerTimeout = 10 * time.Second
|
||||
RequestHeaderMaxSize = 1 << 15
|
||||
RequestStructuredMaxSize = 1 << 22
|
||||
ResponseHeaderMaxSize = 1 << 15
|
||||
ResponseStructuredMaxSize = 1 << 23
|
||||
)
|
||||
|
||||
// the following are protocol constants
|
||||
const (
|
||||
responseHeaderHandlerOk = "HANDLER OK\n"
|
||||
responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n"
|
||||
)
|
||||
|
||||
type streamCopier struct {
|
||||
mtx sync.Mutex
|
||||
used bool
|
||||
streamConn *stream.Conn
|
||||
closeStreamOnClose bool
|
||||
}
|
||||
|
||||
// WriteStreamTo implements zfs.StreamCopier
|
||||
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
if s.used {
|
||||
panic("streamCopier used mulitple times")
|
||||
}
|
||||
s.used = true
|
||||
return s.streamConn.ReadStreamInto(w, ZFSStream)
|
||||
}
|
||||
|
||||
// Close implements zfs.StreamCopier
|
||||
func (s *streamCopier) Close() error {
|
||||
// only record the close here, what we do actually depends on whether
|
||||
// the streamCopier is instantiated server-side or client-side
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
if s.closeStreamOnClose {
|
||||
return s.streamConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
1
rpc/dataconn/dataconn_test.go
Normal file
1
rpc/dataconn/dataconn_test.go
Normal file
@ -0,0 +1 @@
|
||||
package dataconn
|
346
rpc/dataconn/frameconn/frameconn.go
Normal file
346
rpc/dataconn/frameconn/frameconn.go
Normal file
@ -0,0 +1,346 @@
|
||||
package frameconn
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
)
|
||||
|
||||
type FrameHeader struct {
|
||||
Type uint32
|
||||
PayloadLen uint32
|
||||
}
|
||||
|
||||
// The 4 MSBs of ft are reserved for frameconn.
|
||||
func IsPublicFrameType(ft uint32) bool {
|
||||
return (0xf<<28)&ft == 0
|
||||
}
|
||||
|
||||
const (
|
||||
rstFrameType uint32 = 1<<28 + iota
|
||||
)
|
||||
|
||||
func assertPublicFrameType(frameType uint32) {
|
||||
if !IsPublicFrameType(frameType) {
|
||||
panic(fmt.Sprintf("frameconn: frame type %v cannot be used by consumers of this package", frameType))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FrameHeader) Unmarshal(buf []byte) {
|
||||
if len(buf) != 8 {
|
||||
panic(fmt.Sprintf("frame header is 8 bytes long"))
|
||||
}
|
||||
f.Type = binary.BigEndian.Uint32(buf[0:4])
|
||||
f.PayloadLen = binary.BigEndian.Uint32(buf[4:8])
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
readMtx, writeMtx sync.Mutex
|
||||
nc timeoutconn.Conn
|
||||
ncBuf *bufio.ReadWriter
|
||||
readNextValid bool
|
||||
readNext FrameHeader
|
||||
nextReadErr error
|
||||
bufPool *base2bufpool.Pool // no need for sync around it
|
||||
shutdown shutdownFSM
|
||||
}
|
||||
|
||||
func Wrap(nc timeoutconn.Conn) *Conn {
|
||||
return &Conn{
|
||||
nc: nc,
|
||||
// ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)),
|
||||
bufPool: base2bufpool.New(15, 22, base2bufpool.Allocate), // FIXME switch to Panic, but need to enforce the limits in recv for that. => need frameconn config
|
||||
readNext: FrameHeader{},
|
||||
readNextValid: false,
|
||||
}
|
||||
}
|
||||
|
||||
var ErrReadFrameLengthShort = errors.New("read frame length too short")
|
||||
var ErrFixedFrameLengthMismatch = errors.New("read frame length mismatch")
|
||||
|
||||
type Buffer struct {
|
||||
bufpoolBuffer base2bufpool.Buffer
|
||||
payloadLen uint32
|
||||
}
|
||||
|
||||
func (b *Buffer) Free() {
|
||||
b.bufpoolBuffer.Free()
|
||||
}
|
||||
|
||||
func (b *Buffer) Bytes() []byte {
|
||||
return b.bufpoolBuffer.Bytes()[0:b.payloadLen]
|
||||
}
|
||||
|
||||
type Frame struct {
|
||||
Header FrameHeader
|
||||
Buffer Buffer
|
||||
}
|
||||
|
||||
var ErrShutdown = fmt.Errorf("frameconn: shutting down")
|
||||
|
||||
// ReadFrame reads a frame from the connection.
|
||||
//
|
||||
// Due to an internal optimization (Readv, specifically), it is not guaranteed that a single call to
|
||||
// WriteFrame unblocks a pending ReadFrame on an otherwise idle (empty) connection.
|
||||
// The only way to guarantee that all previously written frames can reach the peer's layers on top
|
||||
// of frameconn is to send an empty frame (no payload) and to ignore empty frames on the receiving side.
|
||||
func (c *Conn) ReadFrame() (Frame, error) {
|
||||
|
||||
if c.shutdown.IsShuttingDown() {
|
||||
return Frame{}, ErrShutdown
|
||||
}
|
||||
|
||||
// only aquire readMtx now to prioritize the draining in Shutdown()
|
||||
// over external callers (= drain public callers)
|
||||
|
||||
c.readMtx.Lock()
|
||||
defer c.readMtx.Unlock()
|
||||
f, err := c.readFrame()
|
||||
if f.Header.Type == rstFrameType {
|
||||
c.shutdown.Begin()
|
||||
return Frame{}, ErrShutdown
|
||||
}
|
||||
return f, err
|
||||
}
|
||||
|
||||
// callers must have readMtx locked
|
||||
func (c *Conn) readFrame() (Frame, error) {
|
||||
|
||||
if c.nextReadErr != nil {
|
||||
ret := c.nextReadErr
|
||||
c.nextReadErr = nil
|
||||
return Frame{}, ret
|
||||
}
|
||||
|
||||
if !c.readNextValid {
|
||||
var buf [8]byte
|
||||
if _, err := io.ReadFull(c.nc, buf[:]); err != nil {
|
||||
return Frame{}, err
|
||||
}
|
||||
c.readNext.Unmarshal(buf[:])
|
||||
c.readNextValid = true
|
||||
}
|
||||
|
||||
// read payload + next header
|
||||
var nextHdrBuf [8]byte
|
||||
buffer := c.bufPool.Get(uint(c.readNext.PayloadLen))
|
||||
bufferBytes := buffer.Bytes()
|
||||
|
||||
if c.readNext.PayloadLen == 0 {
|
||||
// This if statement implements the unlock-by-sending-empty-frame behavior
|
||||
// documented in ReadFrame's public docs.
|
||||
//
|
||||
// It is crucial that we return this empty frame now:
|
||||
// Consider the following plot with x-axis being time,
|
||||
// P being a frame with payload, E one without, X either of P or E
|
||||
//
|
||||
// P P P P P P P E.....................X
|
||||
// | | | |
|
||||
// | | | F3
|
||||
// | | |
|
||||
// | F2 |signficant time between frames because
|
||||
// F1 the peer has nothing to say to us
|
||||
//
|
||||
// Assume we're at the point were F2's header is in c.readNext.
|
||||
// That means F2 has not yet been returned.
|
||||
// But because it is empty (no payload), we're already done reading it.
|
||||
// If we omitted this if statement, the following would happen:
|
||||
// Readv below would read [][]byte{[len(0)], [len(8)]).
|
||||
|
||||
c.readNextValid = false
|
||||
frame := Frame{
|
||||
Header: c.readNext,
|
||||
Buffer: Buffer{
|
||||
bufpoolBuffer: buffer,
|
||||
payloadLen: c.readNext.PayloadLen, // 0
|
||||
},
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
noNextHeader := false
|
||||
if n, err := c.nc.ReadvFull([][]byte{bufferBytes, nextHdrBuf[:]}); err != nil {
|
||||
noNextHeader = true
|
||||
zeroPayloadAndPeerClosed := n == 0 && c.readNext.PayloadLen == 0 && err == io.EOF
|
||||
zeroPayloadAndNextFrameHeaderThenPeerClosed := err == io.EOF && c.readNext.PayloadLen == 0 && n == int64(len(nextHdrBuf))
|
||||
nonzeroPayloadRecvdButNextHeaderMissing := n > 0 && uint32(n) == c.readNext.PayloadLen
|
||||
if zeroPayloadAndPeerClosed || zeroPayloadAndNextFrameHeaderThenPeerClosed || nonzeroPayloadRecvdButNextHeaderMissing {
|
||||
// This is the last frame on the conn.
|
||||
// Store the error to be returned on the next invocation of ReadFrame.
|
||||
c.nextReadErr = err
|
||||
// NORETURN, this frame is still valid
|
||||
} else {
|
||||
return Frame{}, err
|
||||
}
|
||||
}
|
||||
|
||||
frame := Frame{
|
||||
Header: c.readNext,
|
||||
Buffer: Buffer{
|
||||
bufpoolBuffer: buffer,
|
||||
payloadLen: c.readNext.PayloadLen,
|
||||
},
|
||||
}
|
||||
|
||||
if !noNextHeader {
|
||||
c.readNext.Unmarshal(nextHdrBuf[:])
|
||||
c.readNextValid = true
|
||||
} else {
|
||||
c.readNextValid = false
|
||||
}
|
||||
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteFrame(payload []byte, frameType uint32) error {
|
||||
assertPublicFrameType(frameType)
|
||||
if c.shutdown.IsShuttingDown() {
|
||||
return ErrShutdown
|
||||
}
|
||||
c.writeMtx.Lock()
|
||||
defer c.writeMtx.Unlock()
|
||||
return c.writeFrame(payload, frameType)
|
||||
}
|
||||
|
||||
func (c *Conn) writeFrame(payload []byte, frameType uint32) error {
|
||||
var hdrBuf [8]byte
|
||||
binary.BigEndian.PutUint32(hdrBuf[0:4], frameType)
|
||||
binary.BigEndian.PutUint32(hdrBuf[4:8], uint32(len(payload)))
|
||||
bufs := net.Buffers([][]byte{hdrBuf[:], payload})
|
||||
if _, err := c.nc.WritevFull(bufs); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Shutdown(deadline time.Time) error {
|
||||
// TCP connection teardown is a bit wonky if we are in a situation
|
||||
// where there is still data in flight (DIF) to our side:
|
||||
// If we just close the connection, our kernel will send RSTs
|
||||
// in response to the DIF, and those RSTs may reach the client's
|
||||
// kernel faster than the client app is able to pull the
|
||||
// last bytes from its kernel TCP receive buffer.
|
||||
//
|
||||
// Therefore, we send a frame with type rstFrameType to indicate
|
||||
// that the connection is to be closed immediately, and further
|
||||
// use CloseWrite instead of Close.
|
||||
// As per definition of the wire interface, CloseWrite guarantees
|
||||
// delivery of the data in our kernel TCP send buffer.
|
||||
// Therefore, the client always receives the RST frame.
|
||||
//
|
||||
// Now what are we going to do after that?
|
||||
//
|
||||
// 1. Naive Option: We just call Close() right after CloseWrite.
|
||||
// This yields the same race condition as explained above (DIF, first
|
||||
// paragraph): The situation just becomae a little more unlikely because
|
||||
// our rstFrameType + CloseWrite dance gave the client a full RTT worth of
|
||||
// time to read the data from its TCP recv buffer.
|
||||
//
|
||||
// 2. Correct Option: Drain the read side until io.EOF
|
||||
// We can read from the unclosed read-side of the connection until we get
|
||||
// the io.EOF caused by the (well behaved) client closing the connection
|
||||
// in response to it reading the rstFrameType frame we sent.
|
||||
// However, this wastes resources on our side (we don't care about the
|
||||
// pending data anymore), and has potential for (D)DoS through CPU-time
|
||||
// exhaustion if the client just keeps sending data.
|
||||
// Then again, this option has the advantage with well-behaved clients
|
||||
// that we do not waste precious kernel-memory on the stale receive buffer
|
||||
// on our side (which is still full of data that we do not intend to read).
|
||||
//
|
||||
// 2.1 DoS Mitigation: Bound the number of bytes to drain, then close
|
||||
// At the time of writing, this technique is practiced by the Go http server
|
||||
// implementation, and actually SHOULDed in the HTTP 1.1 RFC. It is
|
||||
// important to disable the idle timeout of the underlying timeoutconn in
|
||||
// that case and set an absolute deadline by which the socket must have
|
||||
// been fully drained. Not too hard, though ;)
|
||||
//
|
||||
// 2.2: Client sends RST, not FIN when it receives an rstFrameTyp frame.
|
||||
// We can use wire.(*net.TCPConn).SetLinger(0) to force an RST to be sent
|
||||
// on a subsequent close (instead of a FIN + wait for FIN+ACK).
|
||||
// TODO put this into Wire interface as an abstract method.
|
||||
//
|
||||
// 2.3 Only start draining after N*RTT
|
||||
// We have an RTT approximation from Wire.CloseWrite, which by definition
|
||||
// must not return before all to-be-sent-data has been acknowledged by the
|
||||
// client. Give the client a fair chance to react, and only start draining
|
||||
// after a multiple of the RTT has elapsed.
|
||||
// We waste the recv buffer memory a little longer than necessary, iff the
|
||||
// client reacts faster than expected. But we don't wast CPU time.
|
||||
// If we apply 2.2, we'll also have the benefit that our kernel will have
|
||||
// dropped the recv buffer memory as soon as it receives the client's RST.
|
||||
//
|
||||
// 3. TCP-only: OOB-messaging
|
||||
// We can use TCP's 'urgent' flag in the client to acknowledge the receipt
|
||||
// of the rstFrameType to us.
|
||||
// We can thus wait for that signal while leaving the kernel buffer as is.
|
||||
|
||||
// TODO: For now, we just drain the connection (Option 2),
|
||||
// but we enforce deadlines so the _time_ we drain the connection
|
||||
// is bounded, although we do _that_ at full speed
|
||||
|
||||
defer prometheus.NewTimer(prom.ShutdownSeconds).ObserveDuration()
|
||||
|
||||
closeWire := func(step string) error {
|
||||
// TODO SetLinger(0) or similiar (we want RST frames here, not FINS)
|
||||
if closeErr := c.nc.Close(); closeErr != nil {
|
||||
prom.ShutdownCloseErrors.WithLabelValues("close").Inc()
|
||||
return closeErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
hardclose := func(err error, step string) error {
|
||||
prom.ShutdownHardCloses.WithLabelValues(step).Inc()
|
||||
return closeWire(step)
|
||||
}
|
||||
|
||||
c.shutdown.Begin()
|
||||
// new calls to c.ReadFrame and c.WriteFrame will now return ErrShutdown
|
||||
// Aquiring writeMtx and readMtx ensures that the last calls exit successfully
|
||||
|
||||
// disable renewing timeouts now, enforce the requested deadline instead
|
||||
// we need to do this before aquiring locks to enforce the timeout on slow
|
||||
// clients / if something hangs (DoS mitigation)
|
||||
if err := c.nc.DisableTimeouts(); err != nil {
|
||||
return hardclose(err, "disable_timeouts")
|
||||
}
|
||||
if err := c.nc.SetDeadline(deadline); err != nil {
|
||||
return hardclose(err, "set_deadline")
|
||||
}
|
||||
|
||||
c.writeMtx.Lock()
|
||||
defer c.writeMtx.Unlock()
|
||||
|
||||
if err := c.writeFrame([]byte{}, rstFrameType); err != nil {
|
||||
return hardclose(err, "write_frame")
|
||||
}
|
||||
|
||||
if err := c.nc.CloseWrite(); err != nil {
|
||||
return hardclose(err, "close_write")
|
||||
}
|
||||
|
||||
c.readMtx.Lock()
|
||||
defer c.readMtx.Unlock()
|
||||
|
||||
// TODO DoS mitigation: wait for client acknowledgement that they initiated Shutdown,
|
||||
// then perform abortive close on our side. As explained above, probably requires
|
||||
// OOB signaling such as TCP's urgent flag => transport-specific?
|
||||
|
||||
// TODO DoS mitigation by reading limited number of bytes
|
||||
// see discussion above why this is non-trivial
|
||||
defer prometheus.NewTimer(prom.ShutdownDrainSeconds).ObserveDuration()
|
||||
n, _ := io.Copy(ioutil.Discard, c.nc)
|
||||
prom.ShutdownDrainBytesRead.Observe(float64(n))
|
||||
|
||||
return closeWire("close")
|
||||
}
|
63
rpc/dataconn/frameconn/frameconn_prometheus.go
Normal file
63
rpc/dataconn/frameconn/frameconn_prometheus.go
Normal file
@ -0,0 +1,63 @@
|
||||
package frameconn
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
var prom struct {
|
||||
ShutdownDrainBytesRead prometheus.Summary
|
||||
ShutdownSeconds prometheus.Summary
|
||||
ShutdownDrainSeconds prometheus.Summary
|
||||
ShutdownHardCloses *prometheus.CounterVec
|
||||
ShutdownCloseErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
func init() {
|
||||
prom.ShutdownDrainBytesRead = prometheus.NewSummary(prometheus.SummaryOpts{
|
||||
Namespace: "zrepl",
|
||||
Subsystem: "frameconn",
|
||||
Name: "shutdown_drain_bytes_read",
|
||||
Help: "Number of bytes read during the drain phase of connection shutdown",
|
||||
})
|
||||
prom.ShutdownSeconds = prometheus.NewSummary(prometheus.SummaryOpts{
|
||||
Namespace: "zrepl",
|
||||
Subsystem: "frameconn",
|
||||
Name: "shutdown_seconds",
|
||||
Help: "Seconds it took for connection shutdown to complete",
|
||||
})
|
||||
prom.ShutdownDrainSeconds = prometheus.NewSummary(prometheus.SummaryOpts{
|
||||
Namespace: "zrepl",
|
||||
Subsystem: "frameconn",
|
||||
Name: "shutdown_drain_seconds",
|
||||
Help: "Seconds it took from read-side-drain until shutdown completion",
|
||||
})
|
||||
prom.ShutdownHardCloses = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "zrepl",
|
||||
Subsystem: "frameconn",
|
||||
Name: "shutdown_hard_closes",
|
||||
Help: "Number of hard connection closes during shutdown (abortive close)",
|
||||
}, []string{"step"})
|
||||
prom.ShutdownCloseErrors = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "zrepl",
|
||||
Subsystem: "frameconn",
|
||||
Name: "shutdown_close_errors",
|
||||
Help: "Number of errors closing the underlying network connection. Should alert on this",
|
||||
}, []string{"step"})
|
||||
}
|
||||
|
||||
func PrometheusRegister(registry prometheus.Registerer) error {
|
||||
if err := registry.Register(prom.ShutdownDrainBytesRead); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registry.Register(prom.ShutdownSeconds); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registry.Register(prom.ShutdownDrainSeconds); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registry.Register(prom.ShutdownHardCloses); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := registry.Register(prom.ShutdownCloseErrors); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
37
rpc/dataconn/frameconn/frameconn_shutdown_fsm.go
Normal file
37
rpc/dataconn/frameconn/frameconn_shutdown_fsm.go
Normal file
@ -0,0 +1,37 @@
|
||||
package frameconn
|
||||
|
||||
import "sync"
|
||||
|
||||
type shutdownFSM struct {
|
||||
mtx sync.Mutex
|
||||
state shutdownFSMState
|
||||
}
|
||||
|
||||
type shutdownFSMState uint32
|
||||
|
||||
const (
|
||||
shutdownStateOpen shutdownFSMState = iota
|
||||
shutdownStateBegin
|
||||
)
|
||||
|
||||
func newShutdownFSM() *shutdownFSM {
|
||||
fsm := &shutdownFSM{
|
||||
state: shutdownStateOpen,
|
||||
}
|
||||
return fsm
|
||||
}
|
||||
|
||||
func (f *shutdownFSM) Begin() (thisCallStartedShutdown bool) {
|
||||
f.mtx.Lock()
|
||||
defer f.mtx.Unlock()
|
||||
thisCallStartedShutdown = f.state != shutdownStateOpen
|
||||
f.state = shutdownStateBegin
|
||||
return thisCallStartedShutdown
|
||||
}
|
||||
|
||||
func (f *shutdownFSM) IsShuttingDown() bool {
|
||||
f.mtx.Lock()
|
||||
defer f.mtx.Unlock()
|
||||
return f.state != shutdownStateOpen
|
||||
}
|
||||
|
22
rpc/dataconn/frameconn/frameconn_test.go
Normal file
22
rpc/dataconn/frameconn/frameconn_test.go
Normal file
@ -0,0 +1,22 @@
|
||||
package frameconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsPublicFrameType(t *testing.T) {
|
||||
for i := uint32(0); i < 256; i++ {
|
||||
i := i
|
||||
t.Run(fmt.Sprintf("^%d", i), func(t *testing.T) {
|
||||
assert.False(t, IsPublicFrameType(^i))
|
||||
})
|
||||
}
|
||||
assert.True(t, IsPublicFrameType(0))
|
||||
assert.True(t, IsPublicFrameType(1))
|
||||
assert.True(t, IsPublicFrameType(255))
|
||||
assert.False(t, IsPublicFrameType(rstFrameType))
|
||||
}
|
||||
|
137
rpc/dataconn/heartbeatconn/heartbeatconn.go
Normal file
137
rpc/dataconn/heartbeatconn/heartbeatconn.go
Normal file
@ -0,0 +1,137 @@
|
||||
package heartbeatconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
state state
|
||||
// if not nil, opErr is returned for ReadFrame and WriteFrame (not for Close, though)
|
||||
opErr atomic.Value // error
|
||||
fc *frameconn.Conn
|
||||
sendInterval, timeout time.Duration
|
||||
stopSend chan struct{}
|
||||
lastFrameSent atomic.Value // time.Time
|
||||
}
|
||||
|
||||
type HeartbeatTimeout struct{}
|
||||
|
||||
func (e HeartbeatTimeout) Error() string {
|
||||
return "heartbeat timeout"
|
||||
}
|
||||
|
||||
func (e HeartbeatTimeout) Temporary() bool { return true }
|
||||
|
||||
func (e HeartbeatTimeout) Timeout() bool { return true }
|
||||
|
||||
var _ net.Error = HeartbeatTimeout{}
|
||||
|
||||
type state = int32
|
||||
|
||||
const (
|
||||
stateInitial state = 0
|
||||
stateClosed state = 2
|
||||
)
|
||||
|
||||
const (
|
||||
heartbeat uint32 = 1 << 24
|
||||
)
|
||||
|
||||
// The 4 MSBs of ft are reserved for frameconn, we reserve the next 4 MSB for us.
|
||||
func IsPublicFrameType(ft uint32) bool {
|
||||
return frameconn.IsPublicFrameType(ft) && (0xf<<24)&ft == 0
|
||||
}
|
||||
|
||||
func assertPublicFrameType(frameType uint32) {
|
||||
if !IsPublicFrameType(frameType) {
|
||||
panic(fmt.Sprintf("heartbeatconn: frame type %v cannot be used by consumers of this package", frameType))
|
||||
}
|
||||
}
|
||||
|
||||
func Wrap(nc timeoutconn.Wire, sendInterval, timeout time.Duration) *Conn {
|
||||
c := &Conn{
|
||||
fc: frameconn.Wrap(timeoutconn.Wrap(nc, timeout)),
|
||||
stopSend: make(chan struct{}),
|
||||
sendInterval: sendInterval,
|
||||
timeout: timeout,
|
||||
}
|
||||
c.lastFrameSent.Store(time.Now())
|
||||
go c.sendHeartbeats()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Conn) Shutdown() error {
|
||||
normalClose := atomic.CompareAndSwapInt32(&c.state, stateInitial, stateClosed)
|
||||
if normalClose {
|
||||
close(c.stopSend)
|
||||
}
|
||||
return c.fc.Shutdown(time.Now().Add(c.timeout))
|
||||
}
|
||||
|
||||
// started as a goroutine in constructor
|
||||
func (c *Conn) sendHeartbeats() {
|
||||
sleepTime := func(now time.Time) time.Duration {
|
||||
lastSend := c.lastFrameSent.Load().(time.Time)
|
||||
return lastSend.Add(c.sendInterval).Sub(now)
|
||||
}
|
||||
timer := time.NewTimer(sleepTime(time.Now()))
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.stopSend:
|
||||
return
|
||||
case now := <-timer.C:
|
||||
func() {
|
||||
defer func() {
|
||||
timer.Reset(sleepTime(time.Now()))
|
||||
}()
|
||||
if sleepTime(now) > 0 {
|
||||
return
|
||||
}
|
||||
debug("send heartbeat")
|
||||
// if the connection is in zombie mode (aka iptables DROP inbetween peers)
|
||||
// this call or one of its successors will block after filling up the kernel tx buffer
|
||||
c.fc.WriteFrame([]byte{}, heartbeat)
|
||||
// ignore errors from WriteFrame to rate-limit SendHeartbeat retries
|
||||
c.lastFrameSent.Store(time.Now())
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadFrame() (frameconn.Frame, error) {
|
||||
return c.readFrameFiltered()
|
||||
}
|
||||
|
||||
func (c *Conn) readFrameFiltered() (frameconn.Frame, error) {
|
||||
for {
|
||||
f, err := c.fc.ReadFrame()
|
||||
if err != nil {
|
||||
return frameconn.Frame{}, err
|
||||
}
|
||||
if IsPublicFrameType(f.Header.Type) {
|
||||
return f, nil
|
||||
}
|
||||
if f.Header.Type != heartbeat {
|
||||
return frameconn.Frame{}, fmt.Errorf("unknown frame type %x", f.Header.Type)
|
||||
}
|
||||
// drop heartbeat frame
|
||||
debug("received heartbeat")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) WriteFrame(payload []byte, frameType uint32) error {
|
||||
assertPublicFrameType(frameType)
|
||||
err := c.fc.WriteFrame(payload, frameType)
|
||||
if err == nil {
|
||||
c.lastFrameSent.Store(time.Now())
|
||||
}
|
||||
return err
|
||||
}
|
20
rpc/dataconn/heartbeatconn/heartbeatconn_debug.go
Normal file
20
rpc/dataconn/heartbeatconn/heartbeatconn_debug.go
Normal file
@ -0,0 +1,20 @@
|
||||
package heartbeatconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var debugEnabled bool = false
|
||||
|
||||
func init() {
|
||||
if os.Getenv("ZREPL_RPC_DATACONN_HEARTBEATCONN_DEBUG") != "" {
|
||||
debugEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func debug(format string, args ...interface{}) {
|
||||
if debugEnabled {
|
||||
fmt.Fprintf(os.Stderr, "rpc/dataconn/heartbeatconn: %s\n", fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
26
rpc/dataconn/heartbeatconn/heartbeatconn_test.go
Normal file
26
rpc/dataconn/heartbeatconn/heartbeatconn_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package heartbeatconn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
|
||||
)
|
||||
|
||||
func TestFrameTypes(t *testing.T) {
|
||||
assert.True(t, frameconn.IsPublicFrameType(heartbeat))
|
||||
}
|
||||
|
||||
func TestNegativeTimer(t *testing.T) {
|
||||
|
||||
timer := time.NewTimer(-1 * time.Second)
|
||||
defer timer.Stop()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Log("timer with negative time fired, that's what we want")
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
}
|
184
rpc/dataconn/microbenchmark/microbenchmark.go
Normal file
184
rpc/dataconn/microbenchmark/microbenchmark.go
Normal file
@ -0,0 +1,184 @@
|
||||
// microbenchmark to manually test rpc/dataconn perforamnce
|
||||
//
|
||||
// With stdin / stdout on client and server, simulating zfs send|recv piping
|
||||
//
|
||||
// ./microbenchmark -appmode server | pv -r > /dev/null
|
||||
// ./microbenchmark -appmode client -direction recv < /dev/zero
|
||||
//
|
||||
//
|
||||
// Without the overhead of pipes (just protocol perforamnce, mostly useful with perf bc no bw measurement)
|
||||
//
|
||||
// ./microbenchmark -appmode client -direction recv -devnoopWriter -devnoopReader
|
||||
// ./microbenchmark -appmode server -devnoopReader -devnoopWriter
|
||||
//
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/profile"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/util/devnoop"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
func orDie(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type readerStreamCopier struct{ io.Reader }
|
||||
|
||||
func (readerStreamCopier) Close() error { return nil }
|
||||
|
||||
type readerStreamCopierErr struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (readerStreamCopierErr) IsReadError() bool { return false }
|
||||
func (readerStreamCopierErr) IsWriteError() bool { return true }
|
||||
|
||||
func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
|
||||
var buf [1 << 21]byte
|
||||
_, err := io.CopyBuffer(w, c.Reader, buf[:])
|
||||
// always assume write error
|
||||
return readerStreamCopierErr{err}
|
||||
}
|
||||
|
||||
type devNullHandler struct{}
|
||||
|
||||
func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
var res pdu.SendRes
|
||||
if args.devnoopReader {
|
||||
return &res, readerStreamCopier{devnoop.Get()}, nil
|
||||
} else {
|
||||
return &res, readerStreamCopier{os.Stdin}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
var out io.Writer = os.Stdout
|
||||
if args.devnoopWriter {
|
||||
out = devnoop.Get()
|
||||
}
|
||||
err := stream.WriteStreamTo(out)
|
||||
var res pdu.ReceiveRes
|
||||
return &res, err
|
||||
}
|
||||
|
||||
type tcpConnecter struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (c tcpConnecter) Connect(ctx context.Context) (timeoutconn.Wire, error) {
|
||||
conn, err := net.Dial("tcp", c.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn.(*net.TCPConn), nil
|
||||
}
|
||||
|
||||
type tcpListener struct {
|
||||
nl *net.TCPListener
|
||||
clientIdent string
|
||||
}
|
||||
|
||||
func (l tcpListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
tcpconn, err := l.nl.AcceptTCP()
|
||||
orDie(err)
|
||||
return transport.NewAuthConn(tcpconn, l.clientIdent), nil
|
||||
}
|
||||
|
||||
func (l tcpListener) Addr() net.Addr { return l.nl.Addr() }
|
||||
|
||||
func (l tcpListener) Close() error { return l.nl.Close() }
|
||||
|
||||
var args struct {
|
||||
addr string
|
||||
appmode string
|
||||
direction string
|
||||
profile bool
|
||||
devnoopReader bool
|
||||
devnoopWriter bool
|
||||
}
|
||||
|
||||
func server() {
|
||||
|
||||
log := logger.NewStderrDebugLogger()
|
||||
log.Debug("starting server")
|
||||
nl, err := net.Listen("tcp", args.addr)
|
||||
orDie(err)
|
||||
l := tcpListener{nl.(*net.TCPListener), "fakeclientidentity"}
|
||||
|
||||
srv := dataconn.NewServer(nil, logger.NewStderrDebugLogger(), devNullHandler{})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
srv.Serve(ctx, l)
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
flag.BoolVar(&args.profile, "profile", false, "")
|
||||
flag.BoolVar(&args.devnoopReader, "devnoopReader", false, "")
|
||||
flag.BoolVar(&args.devnoopWriter, "devnoopWriter", false, "")
|
||||
flag.StringVar(&args.addr, "address", ":8888", "")
|
||||
flag.StringVar(&args.appmode, "appmode", "client|server", "")
|
||||
flag.StringVar(&args.direction, "direction", "", "send|recv")
|
||||
flag.Parse()
|
||||
|
||||
if args.profile {
|
||||
defer profile.Start(profile.CPUProfile).Stop()
|
||||
}
|
||||
|
||||
switch args.appmode {
|
||||
case "client":
|
||||
client()
|
||||
case "server":
|
||||
server()
|
||||
default:
|
||||
orDie(fmt.Errorf("unknown appmode %q", args.appmode))
|
||||
}
|
||||
}
|
||||
|
||||
func client() {
|
||||
|
||||
logger := logger.NewStderrDebugLogger()
|
||||
ctx := context.Background()
|
||||
|
||||
connecter := tcpConnecter{args.addr}
|
||||
client := dataconn.NewClient(connecter, logger)
|
||||
|
||||
switch args.direction {
|
||||
case "send":
|
||||
req := pdu.SendReq{}
|
||||
_, stream, err := client.ReqSend(ctx, &req)
|
||||
orDie(err)
|
||||
err = stream.WriteStreamTo(os.Stdout)
|
||||
orDie(err)
|
||||
case "recv":
|
||||
var r io.Reader = os.Stdin
|
||||
if args.devnoopReader {
|
||||
r = devnoop.Get()
|
||||
}
|
||||
s := readerStreamCopier{r}
|
||||
req := pdu.ReceiveReq{}
|
||||
_, err := client.ReqRecv(ctx, &req, &s)
|
||||
orDie(err)
|
||||
default:
|
||||
orDie(fmt.Errorf("unknown direction%q", args.direction))
|
||||
}
|
||||
|
||||
}
|
269
rpc/dataconn/stream/stream.go
Normal file
269
rpc/dataconn/stream/stream.go
Normal file
@ -0,0 +1,269 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
contextKeyLogger contextKey = 1 + iota
|
||||
)
|
||||
|
||||
func WithLogger(ctx context.Context, log Logger) context.Context {
|
||||
return context.WithValue(ctx, contextKeyLogger, log)
|
||||
}
|
||||
|
||||
func getLog(ctx context.Context) Logger {
|
||||
log, ok := ctx.Value(contextKeyLogger).(Logger)
|
||||
if !ok {
|
||||
log = logger.NewNullLogger()
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
// Frame types used by this package.
|
||||
// 4 MSBs are reserved for frameconn, next 4 MSB for heartbeatconn, next 4 MSB for us.
|
||||
const (
|
||||
StreamErrTrailer uint32 = 1 << (16 + iota)
|
||||
End
|
||||
// max 16
|
||||
)
|
||||
|
||||
// NOTE: make sure to add a tests for each frame type that checks
|
||||
// whether it is heartbeatconn.IsPublicFrameType()
|
||||
|
||||
// Check whether the given frame type is allowed to be used by
|
||||
// consumers of this package. Intended for use in unit tests.
|
||||
func IsPublicFrameType(ft uint32) bool {
|
||||
return frameconn.IsPublicFrameType(ft) && heartbeatconn.IsPublicFrameType(ft) && ((0xf<<16)&ft == 0)
|
||||
}
|
||||
|
||||
const FramePayloadShift = 19
|
||||
|
||||
var bufpool = base2bufpool.New(FramePayloadShift, FramePayloadShift, base2bufpool.Panic)
|
||||
|
||||
// if sendStream returns an error, that error will be sent as a trailer to the client
|
||||
// ok will return nil, though.
|
||||
func writeStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) {
|
||||
debug("writeStream: enter stype=%v", stype)
|
||||
defer debug("writeStream: return")
|
||||
if stype == 0 {
|
||||
panic("stype must be non-zero")
|
||||
}
|
||||
if !IsPublicFrameType(stype) {
|
||||
panic(fmt.Sprintf("stype %v is not public", stype))
|
||||
}
|
||||
return doWriteStream(ctx, c, stream, stype)
|
||||
}
|
||||
|
||||
func doWriteStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, stype uint32) (errStream, errConn error) {
|
||||
|
||||
// RULE1 (buf == <zero>) XOR (err == nil)
|
||||
type read struct {
|
||||
buf base2bufpool.Buffer
|
||||
err error
|
||||
}
|
||||
|
||||
reads := make(chan read, 5)
|
||||
go func() {
|
||||
for {
|
||||
buffer := bufpool.Get(1 << FramePayloadShift)
|
||||
bufferBytes := buffer.Bytes()
|
||||
n, err := io.ReadFull(stream, bufferBytes)
|
||||
buffer.Shrink(uint(n))
|
||||
// if we received anything, send one read without an error (RULE 1)
|
||||
if n > 0 {
|
||||
reads <- read{buffer, nil}
|
||||
}
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
// happens iff io.ReadFull read io.EOF from stream
|
||||
err = io.EOF
|
||||
}
|
||||
if err != nil {
|
||||
reads <- read{err: err} // RULE1
|
||||
close(reads)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for read := range reads {
|
||||
if read.err == nil {
|
||||
// RULE 1: read.buf is valid
|
||||
// next line is the hot path...
|
||||
writeErr := c.WriteFrame(read.buf.Bytes(), stype)
|
||||
read.buf.Free()
|
||||
if writeErr != nil {
|
||||
return nil, writeErr
|
||||
}
|
||||
continue
|
||||
} else if read.err == io.EOF {
|
||||
if err := c.WriteFrame([]byte{}, End); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
} else {
|
||||
errReader := strings.NewReader(read.err.Error())
|
||||
errReadErrReader, errConnWrite := doWriteStream(ctx, c, errReader, StreamErrTrailer)
|
||||
if errReadErrReader != nil {
|
||||
panic(errReadErrReader) // in-memory, cannot happen
|
||||
}
|
||||
return read.err, errConnWrite
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type ReadStreamErrorKind int
|
||||
|
||||
const (
|
||||
ReadStreamErrorKindConn ReadStreamErrorKind = 1 + iota
|
||||
ReadStreamErrorKindWrite
|
||||
ReadStreamErrorKindSource
|
||||
ReadStreamErrorKindStreamErrTrailerEncoding
|
||||
ReadStreamErrorKindUnexpectedFrameType
|
||||
)
|
||||
|
||||
type ReadStreamError struct {
|
||||
Kind ReadStreamErrorKind
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ReadStreamError) Error() string {
|
||||
kindStr := ""
|
||||
switch e.Kind {
|
||||
case ReadStreamErrorKindConn:
|
||||
kindStr = " read error: "
|
||||
case ReadStreamErrorKindWrite:
|
||||
kindStr = " write error: "
|
||||
case ReadStreamErrorKindSource:
|
||||
kindStr = " source error: "
|
||||
case ReadStreamErrorKindStreamErrTrailerEncoding:
|
||||
kindStr = " source implementation error: "
|
||||
case ReadStreamErrorKindUnexpectedFrameType:
|
||||
kindStr = " protocol error: "
|
||||
}
|
||||
return fmt.Sprintf("stream:%s%s", kindStr, e.Err)
|
||||
}
|
||||
|
||||
var _ net.Error = &ReadStreamError{}
|
||||
|
||||
func (e ReadStreamError) netErr() net.Error {
|
||||
if netErr, ok := e.Err.(net.Error); ok {
|
||||
return netErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e ReadStreamError) Timeout() bool {
|
||||
if netErr := e.netErr(); netErr != nil {
|
||||
return netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e ReadStreamError) Temporary() bool {
|
||||
if netErr := e.netErr(); netErr != nil {
|
||||
return netErr.Temporary()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var _ zfs.StreamCopierError = &ReadStreamError{}
|
||||
|
||||
func (e ReadStreamError) IsReadError() bool {
|
||||
return e.Kind != ReadStreamErrorKindWrite
|
||||
}
|
||||
|
||||
func (e ReadStreamError) IsWriteError() bool {
|
||||
return e.Kind == ReadStreamErrorKindWrite
|
||||
}
|
||||
|
||||
type readFrameResult struct {
|
||||
f frameconn.Frame
|
||||
err error
|
||||
}
|
||||
|
||||
func readFrames(reads chan<- readFrameResult, c *heartbeatconn.Conn) {
|
||||
for {
|
||||
var r readFrameResult
|
||||
r.f, r.err = c.ReadFrame()
|
||||
reads <- r
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStream will close c if an error reading from c or writing to receiver occurs
|
||||
//
|
||||
// readStream calls itself recursively to read multi-frame error trailers
|
||||
// Thus, the reads channel needs to be a parameter.
|
||||
func readStream(reads <-chan readFrameResult, c *heartbeatconn.Conn, receiver io.Writer, stype uint32) *ReadStreamError {
|
||||
|
||||
var f frameconn.Frame
|
||||
for read := range reads {
|
||||
debug("readStream: read frame %v %v", read.f.Header, read.err)
|
||||
f = read.f
|
||||
if read.err != nil {
|
||||
return &ReadStreamError{ReadStreamErrorKindConn, read.err}
|
||||
}
|
||||
if f.Header.Type != stype {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := receiver.Write(f.Buffer.Bytes())
|
||||
if err != nil {
|
||||
f.Buffer.Free()
|
||||
return &ReadStreamError{ReadStreamErrorKindWrite, err} // FIXME wrap as writer error
|
||||
}
|
||||
if n != len(f.Buffer.Bytes()) {
|
||||
f.Buffer.Free()
|
||||
return &ReadStreamError{ReadStreamErrorKindWrite, io.ErrShortWrite}
|
||||
}
|
||||
f.Buffer.Free()
|
||||
}
|
||||
|
||||
if f.Header.Type == End {
|
||||
debug("readStream: End reached")
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.Header.Type == StreamErrTrailer {
|
||||
debug("readStream: begin of StreamErrTrailer")
|
||||
var errBuf bytes.Buffer
|
||||
if n, err := errBuf.Write(f.Buffer.Bytes()); n != len(f.Buffer.Bytes()) || err != nil {
|
||||
panic(fmt.Sprintf("unexpected bytes.Buffer write error: %v %v", n, err))
|
||||
}
|
||||
// recursion ftw! we won't enter this if stmt because stype == StreamErrTrailer in the following call
|
||||
rserr := readStream(reads, c, &errBuf, StreamErrTrailer)
|
||||
if rserr != nil && rserr.Kind == ReadStreamErrorKindWrite {
|
||||
panic(fmt.Sprintf("unexpected bytes.Buffer write error: %s", rserr))
|
||||
} else if rserr != nil {
|
||||
debug("readStream: rserr != nil && != ReadStreamErrorKindWrite: %v %v\n", rserr.Kind, rserr.Err)
|
||||
return rserr
|
||||
}
|
||||
if !utf8.Valid(errBuf.Bytes()) {
|
||||
return &ReadStreamError{ReadStreamErrorKindStreamErrTrailerEncoding, fmt.Errorf("source error, but not encoded as UTF-8")}
|
||||
}
|
||||
return &ReadStreamError{ReadStreamErrorKindSource, fmt.Errorf("%s", errBuf.String())}
|
||||
}
|
||||
|
||||
return &ReadStreamError{ReadStreamErrorKindUnexpectedFrameType, fmt.Errorf("unexpected frame type %v (expected %v)", f.Header.Type, stype)}
|
||||
}
|
194
rpc/dataconn/stream/stream_conn.go
Normal file
194
rpc/dataconn/stream/stream_conn.go
Normal file
@ -0,0 +1,194 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
hc *heartbeatconn.Conn
|
||||
|
||||
// whether the per-conn readFrames goroutine completed
|
||||
waitReadFramesDone chan struct{}
|
||||
// filled by per-conn readFrames goroutine
|
||||
frameReads chan readFrameResult
|
||||
|
||||
// readMtx serializes read stream operations because we inherently only
|
||||
// support a single stream at a time over hc.
|
||||
readMtx sync.Mutex
|
||||
readClean bool
|
||||
allowWriteStreamTo bool
|
||||
|
||||
// writeMtx serializes write stream operations because we inherently only
|
||||
// support a single stream at a time over hc.
|
||||
writeMtx sync.Mutex
|
||||
writeClean bool
|
||||
}
|
||||
|
||||
var readMessageSentinel = fmt.Errorf("read stream complete")
|
||||
|
||||
type writeStreamToErrorUnknownState struct{}
|
||||
|
||||
func (e writeStreamToErrorUnknownState) Error() string {
|
||||
return "dataconn read stream: connection is in unknown state"
|
||||
}
|
||||
|
||||
func (e writeStreamToErrorUnknownState) IsReadError() bool { return true }
|
||||
|
||||
func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false }
|
||||
|
||||
func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn {
|
||||
hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout)
|
||||
conn := &Conn{
|
||||
hc: hc, readClean: true, writeClean: true,
|
||||
waitReadFramesDone: make(chan struct{}),
|
||||
frameReads: make(chan readFrameResult, 5), // FIXME constant
|
||||
}
|
||||
go conn.readFrames()
|
||||
return conn
|
||||
}
|
||||
|
||||
func isConnCleanAfterRead(res *ReadStreamError) bool {
|
||||
return res == nil || res.Kind == ReadStreamErrorKindSource || res.Kind == ReadStreamErrorKindStreamErrTrailerEncoding
|
||||
}
|
||||
|
||||
func isConnCleanAfterWrite(err error) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
var ErrReadFramesStopped = fmt.Errorf("stream: reading frames stopped")
|
||||
|
||||
func (c *Conn) readFrames() {
|
||||
defer close(c.waitReadFramesDone)
|
||||
defer close(c.frameReads)
|
||||
readFrames(c.frameReads, c.hc)
|
||||
}
|
||||
|
||||
func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) ([]byte, *ReadStreamError) {
|
||||
c.readMtx.Lock()
|
||||
defer c.readMtx.Unlock()
|
||||
if !c.readClean {
|
||||
return nil, &ReadStreamError{
|
||||
Kind: ReadStreamErrorKindConn,
|
||||
Err: fmt.Errorf("dataconn read message: connection is in unknown state"),
|
||||
}
|
||||
}
|
||||
|
||||
r, w := io.Pipe()
|
||||
var buf bytes.Buffer
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lr := io.LimitReader(r, int64(maxSize))
|
||||
if _, err := io.Copy(&buf, lr); err != nil && err != readMessageSentinel {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
err := readStream(c.frameReads, c.hc, w, frameType)
|
||||
c.readClean = isConnCleanAfterRead(err)
|
||||
w.CloseWithError(readMessageSentinel)
|
||||
wg.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteStreamTo reads a stream from Conn and writes it to w.
|
||||
func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) zfs.StreamCopierError {
|
||||
c.readMtx.Lock()
|
||||
defer c.readMtx.Unlock()
|
||||
if !c.readClean {
|
||||
return writeStreamToErrorUnknownState{}
|
||||
}
|
||||
var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
|
||||
c.readClean = isConnCleanAfterRead(err)
|
||||
|
||||
// https://golang.org/doc/faq#nil_error
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) error {
|
||||
c.writeMtx.Lock()
|
||||
defer c.writeMtx.Unlock()
|
||||
if !c.writeClean {
|
||||
return fmt.Errorf("dataconn write message: connection is in unknown state")
|
||||
}
|
||||
errBuf, errConn := writeStream(ctx, c.hc, buf, frameType)
|
||||
if errBuf != nil {
|
||||
panic(errBuf)
|
||||
}
|
||||
c.writeClean = isConnCleanAfterWrite(errConn)
|
||||
return errConn
|
||||
}
|
||||
|
||||
func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) error {
|
||||
c.writeMtx.Lock()
|
||||
defer c.writeMtx.Unlock()
|
||||
if !c.writeClean {
|
||||
return fmt.Errorf("dataconn send stream: connection is in unknown state")
|
||||
}
|
||||
|
||||
// avoid io.Pipe if zfs.StreamCopier is an io.Reader
|
||||
var r io.Reader
|
||||
var w *io.PipeWriter
|
||||
streamCopierErrChan := make(chan zfs.StreamCopierError, 1)
|
||||
if reader, ok := src.(io.Reader); ok {
|
||||
r = reader
|
||||
streamCopierErrChan <- nil
|
||||
close(streamCopierErrChan)
|
||||
} else {
|
||||
r, w = io.Pipe()
|
||||
go func() {
|
||||
streamCopierErrChan <- src.WriteStreamTo(w)
|
||||
w.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
type writeStreamRes struct {
|
||||
errStream, errConn error
|
||||
}
|
||||
writeStreamErrChan := make(chan writeStreamRes, 1)
|
||||
go func() {
|
||||
var res writeStreamRes
|
||||
res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType)
|
||||
if w != nil {
|
||||
w.CloseWithError(res.errStream)
|
||||
}
|
||||
writeStreamErrChan <- res
|
||||
}()
|
||||
|
||||
writeRes := <-writeStreamErrChan
|
||||
streamCopierErr := <-streamCopierErrChan
|
||||
c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct?
|
||||
if streamCopierErr != nil && streamCopierErr.IsReadError() {
|
||||
return streamCopierErr // something on our side is bad
|
||||
} else {
|
||||
if writeRes.errStream != nil {
|
||||
return writeRes.errStream
|
||||
} else if writeRes.errConn != nil {
|
||||
return writeRes.errConn
|
||||
}
|
||||
// TODO combined error?
|
||||
return streamCopierErr
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
err := c.hc.Shutdown()
|
||||
<-c.waitReadFramesDone
|
||||
return err
|
||||
}
|
20
rpc/dataconn/stream/stream_debug.go
Normal file
20
rpc/dataconn/stream/stream_debug.go
Normal file
@ -0,0 +1,20 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var debugEnabled bool = false
|
||||
|
||||
func init() {
|
||||
if os.Getenv("ZREPL_RPC_DATACONN_STREAM_DEBUG") != "" {
|
||||
debugEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func debug(format string, args ...interface{}) {
|
||||
if debugEnabled {
|
||||
fmt.Fprintf(os.Stderr, "rpc/dataconn/stream: %s\n", fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
131
rpc/dataconn/stream/stream_test.go
Normal file
131
rpc/dataconn/stream/stream_test.go
Normal file
@ -0,0 +1,131 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
|
||||
"github.com/zrepl/zrepl/util/socketpair"
|
||||
)
|
||||
|
||||
func TestFrameTypesOk(t *testing.T) {
|
||||
t.Logf("%v", End)
|
||||
assert.True(t, heartbeatconn.IsPublicFrameType(End))
|
||||
assert.True(t, heartbeatconn.IsPublicFrameType(StreamErrTrailer))
|
||||
}
|
||||
|
||||
func TestStreamer(t *testing.T) {
|
||||
|
||||
anc, bnc, err := socketpair.SocketPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
hto := 1 * time.Hour
|
||||
a := heartbeatconn.Wrap(anc, hto, hto)
|
||||
b := heartbeatconn.Wrap(bnc, hto, hto)
|
||||
|
||||
log := logger.NewStderrDebugLogger()
|
||||
ctx := WithLogger(context.Background(), log)
|
||||
|
||||
stype := uint32(0x23)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var buf bytes.Buffer
|
||||
buf.Write(
|
||||
bytes.Repeat([]byte{1, 2}, 1<<25),
|
||||
)
|
||||
writeStream(ctx, a, &buf, stype)
|
||||
log.Debug("WriteStream returned")
|
||||
a.Shutdown()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var buf bytes.Buffer
|
||||
ch := make(chan readFrameResult, 5)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
readFrames(ch, b)
|
||||
}()
|
||||
err := readStream(ch, b, &buf, stype)
|
||||
log.WithField("errType", fmt.Sprintf("%T %v", err, err)).Debug("ReadStream returned")
|
||||
assert.Nil(t, err)
|
||||
expected := bytes.Repeat([]byte{1, 2}, 1<<25)
|
||||
assert.True(t, bytes.Equal(expected, buf.Bytes()))
|
||||
b.Shutdown()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
type errReader struct {
|
||||
t *testing.T
|
||||
readErr error
|
||||
}
|
||||
|
||||
func (er errReader) Read(p []byte) (n int, err error) {
|
||||
er.t.Logf("errReader.Read called")
|
||||
return 0, er.readErr
|
||||
}
|
||||
|
||||
func TestMultiFrameStreamErrTraileror(t *testing.T) {
|
||||
anc, bnc, err := socketpair.SocketPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
hto := 1 * time.Hour
|
||||
a := heartbeatconn.Wrap(anc, hto, hto)
|
||||
b := heartbeatconn.Wrap(bnc, hto, hto)
|
||||
|
||||
log := logger.NewStderrDebugLogger()
|
||||
ctx := WithLogger(context.Background(), log)
|
||||
|
||||
longErr := fmt.Errorf("an error that definitley spans more than one frame:\n%s", strings.Repeat("a\n", 1<<4))
|
||||
|
||||
stype := uint32(0x23)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r := errReader{t, longErr}
|
||||
writeStream(ctx, a, &r, stype)
|
||||
a.Shutdown()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer b.Shutdown()
|
||||
var buf bytes.Buffer
|
||||
ch := make(chan readFrameResult, 5)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
readFrames(ch, b)
|
||||
}()
|
||||
err := readStream(ch, b, &buf, stype)
|
||||
t.Logf("%s", err)
|
||||
require.NotNil(t, err)
|
||||
assert.True(t, buf.Len() == 0)
|
||||
assert.Equal(t, err.Kind, ReadStreamErrorKindSource)
|
||||
receivedErr := err.Err.Error()
|
||||
expectedErr := longErr.Error()
|
||||
assert.True(t, receivedErr == expectedErr) // builtin Equals is too slow
|
||||
if receivedErr != expectedErr {
|
||||
t.Logf("lengths: %v %v", len(receivedErr), len(expectedErr))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
12
rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore
vendored
Normal file
12
rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/.gitignore
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
# setup-specific
|
||||
inventory
|
||||
*.retry
|
||||
|
||||
# generated by gen_files.sh
|
||||
files/*ssh_client_identity
|
||||
files/*ssh_client_identity.pub
|
||||
files/*.tls.*.key
|
||||
files/*.tls.*.csr
|
||||
files/*.tls.*.crt
|
||||
files/wireevaluator
|
||||
|
@ -0,0 +1,15 @@
|
||||
This directory contains very hacky test automation for wireevaluator based on nested Ansible playbooks.
|
||||
|
||||
* Copy `inventory.example` to `inventory`
|
||||
* Adjust `inventory` IP addresses as needed
|
||||
* Make sure there's an OpenSSH server running on the serve host
|
||||
* Make sure there's no firewalling whatsoever between the hosts
|
||||
* Run `GENKEYS=1 ./gen_files.sh` to re-generate self-signed TLS certs
|
||||
* Run the following command, adjusting the `wireevaluator_repeat` value to the number of times you want to repeat each test
|
||||
|
||||
```
|
||||
ansible-playbook -i inventory all.yml -e `wireevaluator_repeat=3`
|
||||
```
|
||||
|
||||
Generally, things are fine if the playbook doesn't show any panics from wireevaluator.
|
||||
|
@ -0,0 +1,17 @@
|
||||
- hosts: connect,serve
|
||||
tasks:
|
||||
|
||||
- name: "run test"
|
||||
include: internal_prepare_and_run_repeated.yml
|
||||
wireevaluator_transport: "{{config.0}}"
|
||||
wireevaluator_case: "{{config.1}}"
|
||||
wireevaluator_repeat: "{{wireevaluator_repeat}}"
|
||||
with_cartesian:
|
||||
- [ tls, ssh, tcp ]
|
||||
-
|
||||
- closewrite_server
|
||||
- closewrite_client
|
||||
- readdeadline_server
|
||||
- readdeadline_client
|
||||
loop_control:
|
||||
loop_var: config
|
55
rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh
Executable file
55
rpc/dataconn/timeoutconn/internal/wireevaluator/testbed/gen_files.sh
Executable file
@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
cd "$( dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
FILESDIR="$(pwd)"/files
|
||||
|
||||
echo "[INFO] compile binary"
|
||||
pushd .. >/dev/null
|
||||
go build -o $FILESDIR/wireevaluator
|
||||
popd >/dev/null
|
||||
|
||||
if [ "$GENKEYS" == "" ]; then
|
||||
echo "[INFO] GENKEYS environment variable not set, assumed to be valid"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "[INFO] gen ssh key"
|
||||
ssh-keygen -f "$FILESDIR/wireevaluator.ssh_client_identity" -t ed25519
|
||||
|
||||
echo "[INFO] gen tls keys"
|
||||
|
||||
cakey="$FILESDIR/wireevaluator.tls.ca.key"
|
||||
cacrt="$FILESDIR/wireevaluator.tls.ca.crt"
|
||||
hostprefix="$FILESDIR/wireevaluator.tls"
|
||||
|
||||
openssl genrsa -out "$cakey" 4096
|
||||
openssl req -x509 -new -nodes -key "$cakey" -sha256 -days 1 -out "$cacrt"
|
||||
|
||||
declare -a HOSTS
|
||||
HOSTS+=("theserver")
|
||||
HOSTS+=("theclient")
|
||||
|
||||
for host in "${HOSTS[@]}"; do
|
||||
key="${hostprefix}.${host}.key"
|
||||
csr="${hostprefix}.${host}.csr"
|
||||
crt="${hostprefix}.${host}.crt"
|
||||
openssl genrsa -out "$key" 2048
|
||||
|
||||
(
|
||||
echo "."
|
||||
echo "."
|
||||
echo "."
|
||||
echo "."
|
||||
echo "."
|
||||
echo $host
|
||||
echo "."
|
||||
echo "."
|
||||
echo "."
|
||||
echo "."
|
||||
) | openssl req -new -key "$key" -out "$csr"
|
||||
|
||||
openssl x509 -req -in "$csr" -CA "$cacrt" -CAkey "$cakey" -CAcreateserial -out "$crt" -days 1 -sha256
|
||||
|
||||
done
|
@ -0,0 +1,54 @@
|
||||
---
|
||||
|
||||
- name: compile binary and any key files required
|
||||
local_action: command ./gen_files.sh
|
||||
|
||||
- name: Kill test binary
|
||||
shell: "killall -9 wireevaluator || true"
|
||||
- name: Deploy new binary
|
||||
copy:
|
||||
src: "files/wireevaluator"
|
||||
dest: "/opt/wireevaluator"
|
||||
mode: 0755
|
||||
|
||||
- set_fact:
|
||||
wireevaluator_connect_ip: "{{hostvars['connect'].ansible_host}}"
|
||||
wireevaluator_serve_ip: "{{hostvars['serve'].ansible_host}}"
|
||||
|
||||
- name: Deploy config
|
||||
template:
|
||||
src: "templates/{{wireevaluator_transport}}.yml.j2"
|
||||
dest: "/opt/wireevaluator.yml"
|
||||
|
||||
- name: Deploy client identity
|
||||
copy:
|
||||
src: "files/wireevaluator.{{item}}"
|
||||
dest: "/opt/wireevaluator.{{item}}"
|
||||
mode: 0400
|
||||
with_items:
|
||||
- ssh_client_identity
|
||||
- ssh_client_identity.pub
|
||||
- tls.ca.key
|
||||
- tls.ca.crt
|
||||
- tls.theserver.key
|
||||
- tls.theserver.crt
|
||||
- tls.theclient.key
|
||||
- tls.theclient.crt
|
||||
|
||||
- name: Setup server ssh client identity access
|
||||
when: inventory_hostname == "serve"
|
||||
block:
|
||||
- authorized_key:
|
||||
user: root
|
||||
state: present
|
||||
key: "{{ lookup('file', 'files/wireevaluator.ssh_client_identity.pub') }}"
|
||||
key_options: 'command="/opt/wireevaluator -mode stdinserver -config /opt/wireevaluator.yml client1"'
|
||||
- file:
|
||||
state: directory
|
||||
mode: 0700
|
||||
path: /tmp/wireevaluator_stdinserver
|
||||
|
||||
- name: repeated test
|
||||
include: internal_run_test_prepared_single.yml
|
||||
with_sequence: start=1 end={{wireevaluator_repeat}}
|
||||
|
@ -0,0 +1,38 @@
|
||||
---
|
||||
|
||||
- debug:
|
||||
msg: "run test transport={{wireevaluator_transport}} case={{wireevaluator_case}} repeatedly"
|
||||
|
||||
- name: Run Server
|
||||
when: inventory_hostname == "serve"
|
||||
command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode serve -testcase {{wireevaluator_case}}
|
||||
register: spawn_servers
|
||||
async: 60
|
||||
poll: 0
|
||||
|
||||
- name: Run Client
|
||||
when: inventory_hostname == "connect"
|
||||
command: /opt/wireevaluator -config /opt/wireevaluator.yml -mode connect -testcase {{wireevaluator_case}}
|
||||
register: spawn_clients
|
||||
async: 60
|
||||
poll: 0
|
||||
|
||||
- name: Wait for server shutdown
|
||||
when: inventory_hostname == "serve"
|
||||
async_status:
|
||||
jid: "{{ spawn_servers.ansible_job_id}}"
|
||||
delay: 0.5
|
||||
retries: 10
|
||||
|
||||
- name: Wait for client shutdown
|
||||
when: inventory_hostname == "connect"
|
||||
async_status:
|
||||
jid: "{{ spawn_clients.ansible_job_id}}"
|
||||
delay: 0.5
|
||||
retries: 10
|
||||
|
||||
- name: Wait for connections to die (TIME_WAIT conns)
|
||||
command: sleep 4
|
||||
changed_when: false
|
||||
|
||||
|
@ -0,0 +1,2 @@
|
||||
connect ansible_user=root ansible_host=192.168.122.128 wireevaluator_mode="connect"
|
||||
serve ansible_user=root ansible_host=192.168.122.129 wireevaluator_mode="serve"
|
@ -0,0 +1,13 @@
|
||||
connect:
|
||||
type: ssh+stdinserver
|
||||
host: {{wireevaluator_serve_ip}}
|
||||
user: root
|
||||
port: 22
|
||||
identity_file: /opt/wireevaluator.ssh_client_identity
|
||||
options: # optional, default [], `-o` arguments passed to ssh
|
||||
- "Compression=yes"
|
||||
serve:
|
||||
type: stdinserver
|
||||
client_identities:
|
||||
- "client1"
|
||||
|
@ -0,0 +1,10 @@
|
||||
connect:
|
||||
type: tcp
|
||||
address: "{{wireevaluator_serve_ip}}:8888"
|
||||
serve:
|
||||
type: tcp
|
||||
listen: ":8888"
|
||||
clients: {
|
||||
"{{wireevaluator_connect_ip}}" : "client1"
|
||||
}
|
||||
|
@ -0,0 +1,16 @@
|
||||
connect:
|
||||
type: tls
|
||||
address: "{{wireevaluator_serve_ip}}:8888"
|
||||
ca: "/opt/wireevaluator.tls.ca.crt"
|
||||
cert: "/opt/wireevaluator.tls.theclient.crt"
|
||||
key: "/opt/wireevaluator.tls.theclient.key"
|
||||
server_cn: "theserver"
|
||||
|
||||
serve:
|
||||
type: tls
|
||||
listen: ":8888"
|
||||
ca: "/opt/wireevaluator.tls.ca.crt"
|
||||
cert: "/opt/wireevaluator.tls.theserver.crt"
|
||||
key: "/opt/wireevaluator.tls.theserver.key"
|
||||
client_cns:
|
||||
- "theclient"
|
111
rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go
Normal file
111
rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go
Normal file
@ -0,0 +1,111 @@
|
||||
// a tool to test whether a given transport implements the timeoutconn.Wire interface
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
netssh "github.com/problame/go-netssh"
|
||||
"github.com/zrepl/yaml-config"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
transportconfig "github.com/zrepl/zrepl/transport/fromconfig"
|
||||
)
|
||||
|
||||
func noerror(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Connect config.ConnectEnum
|
||||
Serve config.ServeEnum
|
||||
}
|
||||
|
||||
var args struct {
|
||||
mode string
|
||||
configPath string
|
||||
testCase string
|
||||
}
|
||||
|
||||
var conf Config
|
||||
|
||||
type TestCase interface {
|
||||
Client(wire transport.Wire)
|
||||
Server(wire transport.Wire)
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&args.mode, "mode", "", "connect|serve")
|
||||
flag.StringVar(&args.configPath, "config", "", "config file path")
|
||||
flag.StringVar(&args.testCase, "testcase", "", "")
|
||||
flag.Parse()
|
||||
|
||||
bytes, err := ioutil.ReadFile(args.configPath)
|
||||
noerror(err)
|
||||
err = yaml.UnmarshalStrict(bytes, &conf)
|
||||
noerror(err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
global := &config.Global{
|
||||
Serve: &config.GlobalServe{
|
||||
StdinServer: &config.GlobalStdinServer{
|
||||
SockDir: "/tmp/wireevaluator_stdinserver",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
switch args.mode {
|
||||
case "connect":
|
||||
tc, err := getTestCase(args.testCase)
|
||||
noerror(err)
|
||||
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect)
|
||||
noerror(err)
|
||||
wire, err := connecter.Connect(ctx)
|
||||
noerror(err)
|
||||
tc.Client(wire)
|
||||
case "serve":
|
||||
tc, err := getTestCase(args.testCase)
|
||||
noerror(err)
|
||||
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve)
|
||||
noerror(err)
|
||||
l, err := lf()
|
||||
noerror(err)
|
||||
conn, err := l.Accept(ctx)
|
||||
noerror(err)
|
||||
tc.Server(conn)
|
||||
case "stdinserver":
|
||||
identity := flag.Arg(0)
|
||||
unixaddr := path.Join(global.Serve.StdinServer.SockDir, identity)
|
||||
err := netssh.Proxy(ctx, unixaddr)
|
||||
if err == nil {
|
||||
os.Exit(0)
|
||||
}
|
||||
panic(err)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown mode %q", args.mode))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getTestCase(tcName string) (TestCase, error) {
|
||||
switch tcName {
|
||||
case "closewrite_server":
|
||||
return &CloseWrite{mode: CloseWriteServerSide}, nil
|
||||
case "closewrite_client":
|
||||
return &CloseWrite{mode: CloseWriteClientSide}, nil
|
||||
case "readdeadline_client":
|
||||
return &Deadlines{mode: DeadlineModeClientTimeout}, nil
|
||||
case "readdeadline_server":
|
||||
return &Deadlines{mode: DeadlineModeServerTimeout}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown test case %q", tcName)
|
||||
}
|
||||
}
|
@ -0,0 +1,110 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type CloseWriteMode uint
|
||||
|
||||
const (
|
||||
CloseWriteClientSide CloseWriteMode = 1 + iota
|
||||
CloseWriteServerSide
|
||||
)
|
||||
|
||||
type CloseWrite struct {
|
||||
mode CloseWriteMode
|
||||
}
|
||||
|
||||
// sent repeatedly
|
||||
var closeWriteTestSendData = bytes.Repeat([]byte{0x23, 0x42}, 1<<24)
|
||||
var closeWriteErrorMsg = []byte{0xb, 0xa, 0xd, 0xf, 0x0, 0x0, 0xd}
|
||||
|
||||
func (m CloseWrite) Client(wire transport.Wire) {
|
||||
switch m.mode {
|
||||
case CloseWriteClientSide:
|
||||
m.receiver(wire)
|
||||
case CloseWriteServerSide:
|
||||
m.sender(wire)
|
||||
default:
|
||||
panic(m.mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (m CloseWrite) Server(wire transport.Wire) {
|
||||
switch m.mode {
|
||||
case CloseWriteClientSide:
|
||||
m.sender(wire)
|
||||
case CloseWriteServerSide:
|
||||
m.receiver(wire)
|
||||
default:
|
||||
panic(m.mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (CloseWrite) sender(wire transport.Wire) {
|
||||
defer func() {
|
||||
closeErr := wire.Close()
|
||||
log.Printf("closeErr=%T %s", closeErr, closeErr)
|
||||
}()
|
||||
|
||||
type opResult struct {
|
||||
err error
|
||||
}
|
||||
writeDone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
close(writeDone)
|
||||
for {
|
||||
_, err := wire.Write(closeWriteTestSendData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
<-writeDone
|
||||
}()
|
||||
|
||||
var respBuf bytes.Buffer
|
||||
_, err := io.Copy(&respBuf, wire)
|
||||
if err != nil {
|
||||
log.Fatalf("should have received io.EOF, which is masked by io.Copy, got: %s", err)
|
||||
}
|
||||
if !bytes.Equal(respBuf.Bytes(), closeWriteErrorMsg) {
|
||||
log.Fatalf("did not receive error message, got response with len %v:\n%v", respBuf.Len(), respBuf.Bytes())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (CloseWrite) receiver(wire transport.Wire) {
|
||||
|
||||
// consume half the test data, then detect an error, send it and CloseWrite
|
||||
|
||||
r := io.LimitReader(wire, int64(5 * len(closeWriteTestSendData)/3))
|
||||
_, err := io.Copy(ioutil.Discard, r)
|
||||
noerror(err)
|
||||
|
||||
var errBuf bytes.Buffer
|
||||
errBuf.Write(closeWriteErrorMsg)
|
||||
_, err = io.Copy(wire, &errBuf)
|
||||
noerror(err)
|
||||
|
||||
err = wire.CloseWrite()
|
||||
noerror(err)
|
||||
|
||||
// drain wire, as documented in transport.Wire, this is the only way we know the client closed the conn
|
||||
_, err = io.Copy(ioutil.Discard, wire)
|
||||
if err != nil {
|
||||
// io.Copy masks io.EOF to nil, and we expect io.EOF from the client's Close() call
|
||||
log.Panicf("unexpected error returned from reading conn: %s", err)
|
||||
}
|
||||
|
||||
closeErr := wire.Close()
|
||||
log.Printf("closeErr=%T %s", closeErr, closeErr)
|
||||
|
||||
}
|
@ -0,0 +1,138 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type DeadlineMode uint
|
||||
|
||||
const (
|
||||
DeadlineModeClientTimeout DeadlineMode = 1 + iota
|
||||
DeadlineModeServerTimeout
|
||||
)
|
||||
|
||||
type Deadlines struct {
|
||||
mode DeadlineMode
|
||||
}
|
||||
|
||||
func (d Deadlines) Client(wire transport.Wire) {
|
||||
switch d.mode {
|
||||
case DeadlineModeClientTimeout:
|
||||
d.sleepThenSend(wire)
|
||||
case DeadlineModeServerTimeout:
|
||||
d.sendThenRead(wire)
|
||||
default:
|
||||
panic(d.mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (d Deadlines) Server(wire transport.Wire) {
|
||||
switch d.mode {
|
||||
case DeadlineModeClientTimeout:
|
||||
d.sendThenRead(wire)
|
||||
case DeadlineModeServerTimeout:
|
||||
d.sleepThenSend(wire)
|
||||
default:
|
||||
panic(d.mode)
|
||||
}
|
||||
}
|
||||
|
||||
var deadlinesTimeout = 1 * time.Second
|
||||
|
||||
func (d Deadlines) sleepThenSend(wire transport.Wire) {
|
||||
defer wire.Close()
|
||||
|
||||
log.Print("sleepThenSend")
|
||||
|
||||
// exceed timeout of peer (do not respond to their hi msg)
|
||||
time.Sleep(3 * deadlinesTimeout)
|
||||
// expect that the client has hung up on us by now
|
||||
err := d.sendMsg(wire, "hi")
|
||||
log.Printf("err=%s", err)
|
||||
log.Printf("err=%#v", err)
|
||||
if err == nil {
|
||||
log.Panic("no error")
|
||||
}
|
||||
if _, ok := err.(net.Error); !ok {
|
||||
log.Panic("not a net error")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (d Deadlines) sendThenRead(wire transport.Wire) {
|
||||
|
||||
log.Print("sendThenRead")
|
||||
|
||||
err := d.sendMsg(wire, "hi")
|
||||
noerror(err)
|
||||
|
||||
err = wire.SetReadDeadline(time.Now().Add(deadlinesTimeout))
|
||||
noerror(err)
|
||||
|
||||
m, err := d.recvMsg(wire)
|
||||
log.Printf("m=%q", m)
|
||||
log.Printf("err=%s", err)
|
||||
log.Printf("err=%#v", err)
|
||||
|
||||
// close asap so that the peer get's a 'connection reset by peer' error or similar
|
||||
closeErr := wire.Close()
|
||||
if closeErr != nil {
|
||||
panic(closeErr)
|
||||
}
|
||||
|
||||
var neterr net.Error
|
||||
var ok bool
|
||||
if err == nil {
|
||||
goto unexpErr // works for nil, too
|
||||
}
|
||||
neterr, ok = err.(net.Error)
|
||||
if !ok {
|
||||
log.Println("not a net error")
|
||||
goto unexpErr
|
||||
}
|
||||
if !neterr.Timeout() {
|
||||
log.Println("not a timeout")
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
unexpErr:
|
||||
panic(fmt.Sprintf("sendThenRead: client should have hung up but got error %T %s", err, err))
|
||||
}
|
||||
|
||||
const deadlinesMsgLen = 40
|
||||
|
||||
func (d Deadlines) sendMsg(wire transport.Wire, msg string) error {
|
||||
if len(msg) > deadlinesMsgLen {
|
||||
panic(len(msg))
|
||||
}
|
||||
var buf [deadlinesMsgLen]byte
|
||||
copy(buf[:], []byte(msg))
|
||||
n, err := wire.Write(buf[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != len(buf) {
|
||||
panic("short write not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d Deadlines) recvMsg(wire transport.Wire) (string, error) {
|
||||
|
||||
var buf bytes.Buffer
|
||||
r := io.LimitReader(wire, deadlinesMsgLen)
|
||||
_, err := io.Copy(&buf, r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
|
||||
}
|
288
rpc/dataconn/timeoutconn/timeoutconn.go
Normal file
288
rpc/dataconn/timeoutconn/timeoutconn.go
Normal file
@ -0,0 +1,288 @@
|
||||
// package timeoutconn wraps a Wire to provide idle timeouts
|
||||
// based on Set{Read,Write}Deadline.
|
||||
// Additionally, it exports abstractions for vectored I/O.
|
||||
package timeoutconn
|
||||
|
||||
// NOTE
|
||||
// Readv and Writev are not split-off into a separate package
|
||||
// because we use raw syscalls, bypassing Conn's Read / Write methods.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Wire interface {
|
||||
net.Conn
|
||||
// A call to CloseWrite indicates that no further Write calls will be made to Wire.
|
||||
// The implementation must return an error in case of Write calls after CloseWrite.
|
||||
// On the peer's side, after it read all data written to Wire prior to the call to
|
||||
// CloseWrite on our side, the peer's Read calls must return io.EOF.
|
||||
// CloseWrite must not affect the read-direction of Wire: specifically, the
|
||||
// peer must continue to be able to send, and our side must continue be
|
||||
// able to receive data over Wire.
|
||||
//
|
||||
// Note that CloseWrite may (and most likely will) return sooner than the
|
||||
// peer having received all data written to Wire prior to CloseWrite.
|
||||
// Note further that buffering happening in the network stacks on either side
|
||||
// mandates an explicit acknowledgement from the peer that the connection may
|
||||
// be fully shut down: If we call Close without such acknowledgement, any data
|
||||
// from peer to us that was already in flight may cause connection resets to
|
||||
// be sent from us to the peer via the specific transport protocol. Those
|
||||
// resets (e.g. RST frames) may erase all connection context on the peer,
|
||||
// including data in its receive buffers. Thus, those resets are in race with
|
||||
// a) transmission of data written prior to CloseWrite and
|
||||
// b) the peer application reading from those buffers.
|
||||
//
|
||||
// The WaitForPeerClose method can be used to wait for connection termination,
|
||||
// iff the implementation supports it. If it does not, the only reliable way
|
||||
// to wait for a peer to have read all data from Wire (until io.EOF), is to
|
||||
// expect it to close the wire at that point as well, and to drain Wire until
|
||||
// we also read io.EOF.
|
||||
CloseWrite() error
|
||||
|
||||
// Wait for the peer to close the connection.
|
||||
// No data that could otherwise be Read is lost as a consequence of this call.
|
||||
// The use case for this API is abortive connection shutdown.
|
||||
// To provide any value over draining Wire using io.Read, an implementation
|
||||
// will likely use out-of-bounds messaging mechanisms.
|
||||
// TODO WaitForPeerClose() (supported bool, err error)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Wire
|
||||
renewDeadlinesDisabled int32
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
func Wrap(conn Wire, idleTimeout time.Duration) Conn {
|
||||
return Conn{Wire: conn, idleTimeout: idleTimeout}
|
||||
}
|
||||
|
||||
// DisableTimeouts disables the idle timeout behavior provided by this package.
|
||||
// Existing deadlines are cleared iff the call is the first call to this method.
|
||||
func (c *Conn) DisableTimeouts() error {
|
||||
if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) {
|
||||
return c.SetDeadline(time.Time{})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) renewReadDeadline() error {
|
||||
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
|
||||
return nil
|
||||
}
|
||||
return c.SetReadDeadline(time.Now().Add(c.idleTimeout))
|
||||
}
|
||||
|
||||
func (c *Conn) renewWriteDeadline() error {
|
||||
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
|
||||
return nil
|
||||
}
|
||||
return c.SetWriteDeadline(time.Now().Add(c.idleTimeout))
|
||||
}
|
||||
|
||||
func (c Conn) Read(p []byte) (n int, err error) {
|
||||
n = 0
|
||||
err = nil
|
||||
restart:
|
||||
if err := c.renewReadDeadline(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
var nCurRead int
|
||||
nCurRead, err = c.Wire.Read(p[n:len(p)])
|
||||
n += nCurRead
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurRead > 0 {
|
||||
err = nil
|
||||
goto restart
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c Conn) Write(p []byte) (n int, err error) {
|
||||
n = 0
|
||||
restart:
|
||||
if err := c.renewWriteDeadline(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
var nCurWrite int
|
||||
nCurWrite, err = c.Wire.Write(p[n:len(p)])
|
||||
n += nCurWrite
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 {
|
||||
err = nil
|
||||
goto restart
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Writes the given buffers to Conn, following the sematincs of io.Copy,
|
||||
// but is guaranteed to use the writev system call if the wrapped Wire
|
||||
// support it.
|
||||
// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers).
|
||||
func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
|
||||
n = 0
|
||||
restart:
|
||||
if err := c.renewWriteDeadline(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
var nCurWrite int64
|
||||
nCurWrite, err = io.Copy(c.Wire, &bufs)
|
||||
n += nCurWrite
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 {
|
||||
err = nil
|
||||
goto restart
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
var SyscallConnNotSupported = errors.New("SyscallConn not supported")
|
||||
|
||||
// The interface that must be implemented for vectored I/O support.
|
||||
// If the wrapped Wire does not implement it, a less efficient
|
||||
// fallback implementation is used.
|
||||
// Rest assured that Go's *net.TCPConn implements this interface.
|
||||
type SyscallConner interface {
|
||||
// The sentinel error value SyscallConnNotSupported can be returned
|
||||
// if the support for SyscallConn depends on runtime conditions and
|
||||
// that runtime condition is not met.
|
||||
SyscallConn() (syscall.RawConn, error)
|
||||
}
|
||||
|
||||
var _ SyscallConner = (*net.TCPConn)(nil)
|
||||
|
||||
func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) {
|
||||
vecs = make([]syscall.Iovec, 0, len(buffers))
|
||||
for i := range buffers {
|
||||
totalLen += int64(len(buffers[i]))
|
||||
if len(buffers[i]) == 0 {
|
||||
continue
|
||||
}
|
||||
vecs = append(vecs, syscall.Iovec{
|
||||
Base: &buffers[i][0],
|
||||
Len: uint64(len(buffers[i])),
|
||||
})
|
||||
}
|
||||
return totalLen, vecs
|
||||
}
|
||||
|
||||
// Reads the given buffers full:
|
||||
// Think of io.ReadvFull, but for net.Buffers + using the readv syscall.
|
||||
//
|
||||
// If the underlying Wire is not a SyscallConner, a fallback
|
||||
// ipmlementation based on repeated Conn.Read invocations is used.
|
||||
//
|
||||
// If the connection returned io.EOF, the number of bytes up ritten until
|
||||
// then + io.EOF is returned. This behavior is different to io.ReadFull
|
||||
// which returns io.ErrUnexpectedEOF.
|
||||
func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
|
||||
totalLen, iovecs := buildIovecs(buffers)
|
||||
if debugReadvNoShortReadsAssertEnable {
|
||||
defer debugReadvNoShortReadsAssert(totalLen, n, err)
|
||||
}
|
||||
scc, ok := c.Wire.(SyscallConner)
|
||||
if !ok {
|
||||
return c.readvFallback(buffers)
|
||||
}
|
||||
raw, err := scc.SyscallConn()
|
||||
if err == SyscallConnNotSupported {
|
||||
return c.readvFallback(buffers)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err = c.readv(raw, iovecs)
|
||||
return
|
||||
}
|
||||
|
||||
func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
|
||||
buffers := [][]byte(nbuffers)
|
||||
for i := range buffers {
|
||||
curBuf := buffers[i]
|
||||
inner:
|
||||
for len(curBuf) > 0 {
|
||||
if err := c.renewReadDeadline(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
var oneN int
|
||||
oneN, err = c.Read(curBuf[:]) // WE WANT NO SHADOWING
|
||||
curBuf = curBuf[oneN:]
|
||||
n += int64(oneN)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && oneN > 0 {
|
||||
continue inner
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c Conn) readv(rawConn syscall.RawConn, iovecs []syscall.Iovec) (n int64, err error) {
|
||||
for len(iovecs) > 0 {
|
||||
if err := c.renewReadDeadline(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
oneN, oneErr := c.doOneReadv(rawConn, &iovecs)
|
||||
n += oneN
|
||||
if netErr, ok := oneErr.(net.Error); ok && netErr.Timeout() && oneN > 0 { // TODO likely not working
|
||||
continue
|
||||
} else if oneErr == nil && oneN > 0 {
|
||||
continue
|
||||
} else {
|
||||
return n, oneErr
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
|
||||
rawReadErr := rawConn.Read(func(fd uintptr) (done bool) {
|
||||
// iovecs, n and err must not be shadowed!
|
||||
thisReadN, _, errno := syscall.Syscall(
|
||||
syscall.SYS_READV,
|
||||
fd,
|
||||
uintptr(unsafe.Pointer(&(*iovecs)[0])),
|
||||
uintptr(len(*iovecs)),
|
||||
)
|
||||
if thisReadN == ^uintptr(0) {
|
||||
if errno == syscall.EAGAIN {
|
||||
return false
|
||||
}
|
||||
err = syscall.Errno(errno)
|
||||
return true
|
||||
}
|
||||
if int(thisReadN) < 0 {
|
||||
panic("unexpected return value")
|
||||
}
|
||||
n += int64(thisReadN)
|
||||
// shift iovecs forward
|
||||
for left := int64(thisReadN); left > 0; {
|
||||
curVecNewLength := int64((*iovecs)[0].Len) - left // TODO assert conversion
|
||||
if curVecNewLength <= 0 {
|
||||
left -= int64((*iovecs)[0].Len)
|
||||
*iovecs = (*iovecs)[1:]
|
||||
} else {
|
||||
(*iovecs)[0].Base = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer((*iovecs)[0].Base)) + uintptr(left)))
|
||||
(*iovecs)[0].Len = uint64(curVecNewLength)
|
||||
break // inner
|
||||
}
|
||||
}
|
||||
if thisReadN == 0 {
|
||||
err = io.EOF
|
||||
return true
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if rawReadErr != nil {
|
||||
err = rawReadErr
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
19
rpc/dataconn/timeoutconn/timeoutconn_debug.go
Normal file
19
rpc/dataconn/timeoutconn/timeoutconn_debug.go
Normal file
@ -0,0 +1,19 @@
|
||||
package timeoutconn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const debugReadvNoShortReadsAssertEnable = false
|
||||
|
||||
func debugReadvNoShortReadsAssert(expectedLen, returnedLen int64, returnedErr error) {
|
||||
readShort := expectedLen != returnedLen
|
||||
if !readShort {
|
||||
return
|
||||
}
|
||||
if returnedErr != io.EOF {
|
||||
return
|
||||
}
|
||||
panic(fmt.Sprintf("ReadvFull short and error is not EOF%v\n", returnedErr))
|
||||
}
|
177
rpc/dataconn/timeoutconn/timeoutconn_test.go
Normal file
177
rpc/dataconn/timeoutconn/timeoutconn_test.go
Normal file
@ -0,0 +1,177 @@
|
||||
package timeoutconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zrepl/zrepl/util/socketpair"
|
||||
)
|
||||
|
||||
func TestReadTimeout(t *testing.T) {
|
||||
|
||||
a, b, err := socketpair.SocketPair()
|
||||
require.NoError(t, err)
|
||||
defer a.Close()
|
||||
defer b.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("tooktoolong")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
_, err := io.Copy(a, &buf)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := Wrap(b, 100*time.Millisecond)
|
||||
buf := [4]byte{} // shorter than message put on wire
|
||||
n, err := conn.Read(buf[:])
|
||||
assert.Equal(t, 0, n)
|
||||
assert.Error(t, err)
|
||||
netErr, ok := err.(net.Error)
|
||||
require.True(t, ok)
|
||||
assert.True(t, netErr.Timeout())
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type writeBlockConn struct {
|
||||
net.Conn
|
||||
blockTime time.Duration
|
||||
}
|
||||
|
||||
func (c writeBlockConn) Write(p []byte) (int, error) {
|
||||
time.Sleep(c.blockTime)
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c writeBlockConn) CloseWrite() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func TestWriteTimeout(t *testing.T) {
|
||||
a, b, err := socketpair.SocketPair()
|
||||
require.NoError(t, err)
|
||||
defer a.Close()
|
||||
defer b.Close()
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("message")
|
||||
blockConn := writeBlockConn{a, 500 * time.Millisecond}
|
||||
conn := Wrap(blockConn, 100*time.Millisecond)
|
||||
n, err := conn.Write(buf.Bytes())
|
||||
assert.Equal(t, 0, n)
|
||||
assert.Error(t, err)
|
||||
netErr, ok := err.(net.Error)
|
||||
require.True(t, ok)
|
||||
assert.True(t, netErr.Timeout())
|
||||
}
|
||||
|
||||
func TestNoPartialReadsDueToDeadline(t *testing.T) {
|
||||
a, b, err := socketpair.SocketPair()
|
||||
require.NoError(t, err)
|
||||
defer a.Close()
|
||||
defer b.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
a.Write([]byte{1, 2, 3, 4, 5})
|
||||
// sleep to provoke a partial read in the consumer goroutine
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
a.Write([]byte{6, 7, 8, 9, 10})
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bc := Wrap(b, 100*time.Millisecond)
|
||||
var buf bytes.Buffer
|
||||
beginRead := time.Now()
|
||||
// io.Copy will encounter a partial read, then wait ~50ms until the other 5 bytes are written
|
||||
// It is still going to fail with deadline err because it expects EOF
|
||||
n, err := io.Copy(&buf, bc)
|
||||
readDuration := time.Now().Sub(beginRead)
|
||||
t.Logf("read duration=%s", readDuration)
|
||||
t.Logf("recv done n=%v err=%v", n, err)
|
||||
t.Logf("buf=%v", buf.Bytes())
|
||||
neterr, ok := err.(net.Error)
|
||||
require.True(t, ok)
|
||||
assert.True(t, neterr.Timeout())
|
||||
|
||||
assert.Equal(t, int64(10), n)
|
||||
assert.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, buf.Bytes())
|
||||
// 50ms for the second read, 100ms after that one for the deadline
|
||||
// allow for some jitter
|
||||
assert.True(t, readDuration > 140*time.Millisecond)
|
||||
assert.True(t, readDuration < 160*time.Millisecond)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type partialWriteMockConn struct {
|
||||
net.Conn // to satisfy interface
|
||||
buf bytes.Buffer
|
||||
writeDuration time.Duration
|
||||
returnAfterBytesWritten int
|
||||
}
|
||||
|
||||
func newPartialWriteMockConn(writeDuration time.Duration, returnAfterBytesWritten int) *partialWriteMockConn {
|
||||
return &partialWriteMockConn{
|
||||
writeDuration: writeDuration,
|
||||
returnAfterBytesWritten: returnAfterBytesWritten,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *partialWriteMockConn) Write(p []byte) (int, error) {
|
||||
time.Sleep(c.writeDuration)
|
||||
consumeBytes := len(p)
|
||||
if consumeBytes > c.returnAfterBytesWritten {
|
||||
consumeBytes = c.returnAfterBytesWritten
|
||||
}
|
||||
n, err := c.buf.Write(p[0:consumeBytes])
|
||||
if err != nil || n != consumeBytes {
|
||||
panic("bytes.Buffer behaves unexpectedly")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func TestPartialWriteMockConn(t *testing.T) {
|
||||
mc := newPartialWriteMockConn(100*time.Millisecond, 5)
|
||||
buf := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
begin := time.Now()
|
||||
n, err := mc.Write(buf[:])
|
||||
duration := time.Now().Sub(begin)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, n)
|
||||
assert.True(t, duration > 100*time.Millisecond)
|
||||
assert.True(t, duration < 150*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestNoPartialWritesDueToDeadline(t *testing.T) {
|
||||
a, b, err := socketpair.SocketPair()
|
||||
require.NoError(t, err)
|
||||
defer a.Close()
|
||||
defer b.Close()
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("message")
|
||||
blockConn := writeBlockConn{a, 150 * time.Millisecond}
|
||||
conn := Wrap(blockConn, 100*time.Millisecond)
|
||||
n, err := conn.Write(buf.Bytes())
|
||||
assert.Equal(t, 0, n)
|
||||
assert.Error(t, err)
|
||||
netErr, ok := err.(net.Error)
|
||||
require.True(t, ok)
|
||||
assert.True(t, netErr.Timeout())
|
||||
}
|
118
rpc/grpcclientidentity/authlistener_grpc_adaptor.go
Normal file
118
rpc/grpcclientidentity/authlistener_grpc_adaptor.go
Normal file
@ -0,0 +1,118 @@
|
||||
// Package grpcclientidentity makes the client identity
|
||||
// provided by github.com/zrepl/zrepl/daemon/transport/serve.{AuthenticatedListener,AuthConn}
|
||||
// available to gRPC service handlers.
|
||||
//
|
||||
// This goal is achieved through the combination of custom gRPC transport credentials and two interceptors
|
||||
// (i.e. middleware).
|
||||
//
|
||||
// For gRPC clients, the TransportCredentials + Dialer can be used to construct a gRPC client (grpc.ClientConn)
|
||||
// that uses a github.com/zrepl/zrepl/daemon/transport/connect.Connecter to connect to a server.
|
||||
//
|
||||
// The adaptors exposed by this package must be used together, and panic if they are not.
|
||||
// See package grpchelper for a more restrictive but safe example on how the adaptors should be composed.
|
||||
package grpcclientidentity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
type GRPCDialFunction = func(string, time.Duration) (net.Conn, error)
|
||||
|
||||
func NewDialer(logger Logger, connecter transport.Connecter) GRPCDialFunction {
|
||||
return func(s string, duration time.Duration) (conn net.Conn, e error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||
defer cancel()
|
||||
nc, err := connecter.Connect(ctx)
|
||||
// TODO find better place (callback from gRPC?) where to log errors
|
||||
// we want the users to know, though
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("cannot connect")
|
||||
}
|
||||
return nc, err
|
||||
}
|
||||
}
|
||||
|
||||
type authConnAuthType struct {
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (authConnAuthType) AuthType() string {
|
||||
return "AuthConn"
|
||||
}
|
||||
|
||||
type connecterAuthType struct{}
|
||||
|
||||
func (connecterAuthType) AuthType() string {
|
||||
return "connecter"
|
||||
}
|
||||
|
||||
type transportCredentials struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Use on both sides as ServerOption or ClientOption.
|
||||
func NewTransportCredentials(log Logger) credentials.TransportCredentials {
|
||||
if log == nil {
|
||||
log = logger.NewNullLogger()
|
||||
}
|
||||
return &transportCredentials{log}
|
||||
}
|
||||
|
||||
func (c *transportCredentials) ClientHandshake(ctx context.Context, s string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
c.logger.WithField("url", s).WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ClientHandshake")
|
||||
// do nothing, client credential is only for WithInsecure warning to go away
|
||||
// the authentication is done by the connecter
|
||||
return rawConn, &connecterAuthType{}, nil
|
||||
}
|
||||
|
||||
func (c *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
c.logger.WithField("connType", fmt.Sprintf("%T", rawConn)).Debug("ServerHandshake")
|
||||
authConn, ok := rawConn.(*transport.AuthConn)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("NewTransportCredentials must be used with a listener that returns *transport.AuthConn, got %T", rawConn))
|
||||
}
|
||||
return rawConn, &authConnAuthType{authConn.ClientIdentity()}, nil
|
||||
}
|
||||
|
||||
func (*transportCredentials) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{} // TODO
|
||||
}
|
||||
|
||||
func (t *transportCredentials) Clone() credentials.TransportCredentials {
|
||||
var x = *t
|
||||
return &x
|
||||
}
|
||||
|
||||
func (*transportCredentials) OverrideServerName(string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func NewInterceptors(logger Logger, clientIdentityKey interface{}) (unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) {
|
||||
unary = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
logger.WithField("fullMethod", info.FullMethod).Debug("request")
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
panic("peer.FromContext expected to return a peer in grpc.UnaryServerInterceptor")
|
||||
}
|
||||
logger.WithField("peer", fmt.Sprintf("%v", p)).Debug("peer")
|
||||
a, ok := p.AuthInfo.(*authConnAuthType)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("NewInterceptors must be used in combination with grpc.NewTransportCredentials, but got auth type %T", p.AuthInfo))
|
||||
}
|
||||
ctx = context.WithValue(ctx, clientIdentityKey, a.clientIdentity)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
stream = nil
|
||||
return
|
||||
}
|
16
rpc/grpcclientidentity/example/grpcauth.proto
Normal file
16
rpc/grpcclientidentity/example/grpcauth.proto
Normal file
@ -0,0 +1,16 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package pdu;
|
||||
|
||||
|
||||
service Greeter {
|
||||
rpc Greet(GreetRequest) returns (GreetResponse) {}
|
||||
}
|
||||
|
||||
message GreetRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message GreetResponse {
|
||||
string msg = 1;
|
||||
}
|
107
rpc/grpcclientidentity/example/main.go
Normal file
107
rpc/grpcclientidentity/example/main.go
Normal file
@ -0,0 +1,107 @@
|
||||
// This package demonstrates how the grpcclientidentity package can be used
|
||||
// to set up a gRPC greeter service.
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/rpc/grpcclientidentity/example/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper"
|
||||
"github.com/zrepl/zrepl/transport/tcp"
|
||||
)
|
||||
|
||||
var args struct {
|
||||
mode string
|
||||
}
|
||||
|
||||
var log = logger.NewStderrDebugLogger()
|
||||
|
||||
func main() {
|
||||
flag.StringVar(&args.mode, "mode", "", "client|server")
|
||||
flag.Parse()
|
||||
|
||||
switch args.mode {
|
||||
case "client":
|
||||
client()
|
||||
case "server":
|
||||
server()
|
||||
default:
|
||||
log.Printf("unknown mode %q")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func onErr(err error, format string, args ...interface{}) {
|
||||
log.WithError(err).Error(fmt.Sprintf("%s: %s", fmt.Sprintf(format, args...), err))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func client() {
|
||||
cn, err := tcp.TCPConnecterFromConfig(&config.TCPConnect{
|
||||
ConnectCommon: config.ConnectCommon{
|
||||
Type: "tcp",
|
||||
},
|
||||
Address: "127.0.0.1:8080",
|
||||
DialTimeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
onErr(err, "build connecter error")
|
||||
}
|
||||
|
||||
clientConn := grpchelper.ClientConn(cn, log)
|
||||
defer clientConn.Close()
|
||||
|
||||
// normal usage from here on
|
||||
|
||||
client := pdu.NewGreeterClient(clientConn)
|
||||
resp, err := client.Greet(context.Background(), &pdu.GreetRequest{Name: "somethingimadeup"})
|
||||
if err != nil {
|
||||
onErr(err, "RPC error")
|
||||
}
|
||||
|
||||
fmt.Printf("got response:\n\t%s\n", resp.GetMsg())
|
||||
}
|
||||
|
||||
const clientIdentityKey = "clientIdentity"
|
||||
|
||||
func server() {
|
||||
authListenerFactory, err := tcp.TCPListenerFactoryFromConfig(nil, &config.TCPServe{
|
||||
ServeCommon: config.ServeCommon{
|
||||
Type: "tcp",
|
||||
},
|
||||
Listen: "127.0.0.1:8080",
|
||||
Clients: map[string]string{
|
||||
"127.0.0.1": "localclient",
|
||||
"::1": "localclient",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
onErr(err, "cannot build listener factory")
|
||||
}
|
||||
|
||||
log := logger.NewStderrDebugLogger()
|
||||
|
||||
srv, serve, err := grpchelper.NewServer(authListenerFactory, clientIdentityKey, log)
|
||||
svc := &greeter{"hello "}
|
||||
pdu.RegisterGreeterServer(srv, svc)
|
||||
|
||||
if err := serve(); err != nil {
|
||||
onErr(err, "error serving")
|
||||
}
|
||||
}
|
||||
|
||||
type greeter struct {
|
||||
prepend string
|
||||
}
|
||||
|
||||
func (g *greeter) Greet(ctx context.Context, r *pdu.GreetRequest) (*pdu.GreetResponse, error) {
|
||||
ci, _ := ctx.Value(clientIdentityKey).(string)
|
||||
log.WithField("clientIdentity", ci).Info("Greet() request") // show that we got the client identity
|
||||
return &pdu.GreetResponse{Msg: fmt.Sprintf("%s%s (clientIdentity=%q)", g.prepend, r.GetName(), ci)}, nil
|
||||
}
|
193
rpc/grpcclientidentity/example/pdu/grpcauth.pb.go
Normal file
193
rpc/grpcclientidentity/example/pdu/grpcauth.pb.go
Normal file
@ -0,0 +1,193 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: grpcauth.proto
|
||||
|
||||
package pdu
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
math "math"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type GreetRequest struct {
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *GreetRequest) Reset() { *m = GreetRequest{} }
|
||||
func (m *GreetRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*GreetRequest) ProtoMessage() {}
|
||||
func (*GreetRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_1dfba7be0cf69353, []int{0}
|
||||
}
|
||||
|
||||
func (m *GreetRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_GreetRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *GreetRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_GreetRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *GreetRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_GreetRequest.Merge(m, src)
|
||||
}
|
||||
func (m *GreetRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_GreetRequest.Size(m)
|
||||
}
|
||||
func (m *GreetRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_GreetRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_GreetRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *GreetRequest) GetName() string {
|
||||
if m != nil {
|
||||
return m.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type GreetResponse struct {
|
||||
Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *GreetResponse) Reset() { *m = GreetResponse{} }
|
||||
func (m *GreetResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*GreetResponse) ProtoMessage() {}
|
||||
func (*GreetResponse) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_1dfba7be0cf69353, []int{1}
|
||||
}
|
||||
|
||||
func (m *GreetResponse) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_GreetResponse.Unmarshal(m, b)
|
||||
}
|
||||
func (m *GreetResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_GreetResponse.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *GreetResponse) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_GreetResponse.Merge(m, src)
|
||||
}
|
||||
func (m *GreetResponse) XXX_Size() int {
|
||||
return xxx_messageInfo_GreetResponse.Size(m)
|
||||
}
|
||||
func (m *GreetResponse) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_GreetResponse.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_GreetResponse proto.InternalMessageInfo
|
||||
|
||||
func (m *GreetResponse) GetMsg() string {
|
||||
if m != nil {
|
||||
return m.Msg
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*GreetRequest)(nil), "pdu.GreetRequest")
|
||||
proto.RegisterType((*GreetResponse)(nil), "pdu.GreetResponse")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("grpcauth.proto", fileDescriptor_1dfba7be0cf69353) }
|
||||
|
||||
var fileDescriptor_1dfba7be0cf69353 = []byte{
|
||||
// 137 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4b, 0x2f, 0x2a, 0x48,
|
||||
0x4e, 0x2c, 0x2d, 0xc9, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2e, 0x48, 0x29, 0x55,
|
||||
0x52, 0xe2, 0xe2, 0x71, 0x2f, 0x4a, 0x4d, 0x2d, 0x09, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11,
|
||||
0x12, 0xe2, 0x62, 0xc9, 0x4b, 0xcc, 0x4d, 0x95, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x02, 0xb3,
|
||||
0x95, 0x14, 0xb9, 0x78, 0xa1, 0x6a, 0x8a, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0x04, 0xb8, 0x98,
|
||||
0x73, 0x8b, 0xd3, 0xa1, 0x6a, 0x40, 0x4c, 0x23, 0x6b, 0x2e, 0x76, 0xb0, 0x92, 0xd4, 0x22, 0x21,
|
||||
0x03, 0x2e, 0x56, 0x30, 0x53, 0x48, 0x50, 0xaf, 0x20, 0xa5, 0x54, 0x0f, 0xd9, 0x74, 0x29, 0x21,
|
||||
0x64, 0x21, 0x88, 0x61, 0x4a, 0x0c, 0x49, 0x6c, 0x60, 0xf7, 0x18, 0x03, 0x02, 0x00, 0x00, 0xff,
|
||||
0xff, 0xa8, 0x53, 0x2f, 0x4c, 0xa1, 0x00, 0x00, 0x00,
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// GreeterClient is the client API for Greeter service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type GreeterClient interface {
|
||||
Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error)
|
||||
}
|
||||
|
||||
type greeterClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
|
||||
return &greeterClient{cc}
|
||||
}
|
||||
|
||||
func (c *greeterClient) Greet(ctx context.Context, in *GreetRequest, opts ...grpc.CallOption) (*GreetResponse, error) {
|
||||
out := new(GreetResponse)
|
||||
err := c.cc.Invoke(ctx, "/pdu.Greeter/Greet", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GreeterServer is the server API for Greeter service.
|
||||
type GreeterServer interface {
|
||||
Greet(context.Context, *GreetRequest) (*GreetResponse, error)
|
||||
}
|
||||
|
||||
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
|
||||
s.RegisterService(&_Greeter_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Greeter_Greet_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GreetRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(GreeterServer).Greet(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/pdu.Greeter/Greet",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(GreeterServer).Greet(ctx, req.(*GreetRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _Greeter_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "pdu.Greeter",
|
||||
HandlerType: (*GreeterServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Greet",
|
||||
Handler: _Greeter_Greet_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "grpcauth.proto",
|
||||
}
|
@ -0,0 +1,76 @@
|
||||
// Package grpchelper wraps the adaptors implemented by package grpcclientidentity into a less flexible API
|
||||
// which, however, ensures that the individual adaptor primitive's expectations are met and hence do not panic.
|
||||
package grpchelper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/rpc/grpcclientidentity"
|
||||
"github.com/zrepl/zrepl/rpc/netadaptor"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
// The following constants are relevant for interoperability.
|
||||
// We use the same values for client & server, because zrepl is more
|
||||
// symmetrical ("one source, one sink") instead of the typical
|
||||
// gRPC scenario ("many clients, single server")
|
||||
const (
|
||||
StartKeepalivesAfterInactivityDuration = 5 * time.Second
|
||||
KeepalivePeerTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
// ClientConn is an easy-to-use wrapper around the Dialer and TransportCredentials interface
|
||||
// to produce a grpc.ClientConn
|
||||
func ClientConn(cn transport.Connecter, log Logger) *grpc.ClientConn {
|
||||
ka := grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: StartKeepalivesAfterInactivityDuration,
|
||||
Timeout: KeepalivePeerTimeout,
|
||||
PermitWithoutStream: true,
|
||||
})
|
||||
dialerOption := grpc.WithDialer(grpcclientidentity.NewDialer(log, cn))
|
||||
cred := grpc.WithTransportCredentials(grpcclientidentity.NewTransportCredentials(log))
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
|
||||
defer cancel()
|
||||
cc, err := grpc.DialContext(ctx, "doesn't matter done by dialer", dialerOption, cred, ka)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("cannot create gRPC client conn (non-blocking)")
|
||||
// It's ok to panic here: the we call grpc.DialContext without the
|
||||
// (grpc.WithBlock) dial option, and at the time of writing, the grpc
|
||||
// docs state that no connection attempt is made in that case.
|
||||
// Hence, any error that occurs is due to DialOptions or similar,
|
||||
// and thus indicative of an implementation error.
|
||||
panic(err)
|
||||
}
|
||||
return cc
|
||||
}
|
||||
|
||||
// NewServer is a convenience interface around the TransportCredentials and Interceptors interface.
|
||||
func NewServer(authListenerFactory transport.AuthenticatedListenerFactory, clientIdentityKey interface{}, logger grpcclientidentity.Logger) (srv *grpc.Server, serve func() error, err error) {
|
||||
ka := grpc.KeepaliveParams(keepalive.ServerParameters{
|
||||
Time: StartKeepalivesAfterInactivityDuration,
|
||||
Timeout: KeepalivePeerTimeout,
|
||||
})
|
||||
tcs := grpcclientidentity.NewTransportCredentials(logger)
|
||||
unary, stream := grpcclientidentity.NewInterceptors(logger, clientIdentityKey)
|
||||
srv = grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream), ka)
|
||||
|
||||
serve = func() error {
|
||||
l, err := authListenerFactory()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := srv.Serve(netadaptor.New(l, logger)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return srv, serve, nil
|
||||
}
|
102
rpc/netadaptor/authlistener_netlistener_adaptor.go
Normal file
102
rpc/netadaptor/authlistener_netlistener_adaptor.go
Normal file
@ -0,0 +1,102 @@
|
||||
// Package netadaptor implements an adaptor from
|
||||
// transport.AuthenticatedListener to net.Listener.
|
||||
//
|
||||
// In contrast to transport.AuthenticatedListener,
|
||||
// net.Listener is commonly expected (e.g. by net/http.Server.Serve),
|
||||
// to return errors that fulfill the Temporary interface:
|
||||
// interface Temporary { Temporary() bool }
|
||||
// Common behavior of packages consuming net.Listener is to return
|
||||
// from the serve-loop if an error returned by Accept is not Temporary,
|
||||
// i.e., does not implement the interface or is !net.Error.Temporary().
|
||||
//
|
||||
// The zrepl transport infrastructure was written with the
|
||||
// idea that Accept() may return any kind of error, and the consumer
|
||||
// would just log the error and continue calling Accept().
|
||||
// We have to adapt these listeners' behavior to the expectations
|
||||
// of net/http.Server.
|
||||
//
|
||||
// Hence, Listener does not return an error at all but blocks the
|
||||
// caller of Accept() until we get a (successfully authenticated)
|
||||
// connection without errors from the transport.
|
||||
// Accept errors returned from the transport are logged as errors
|
||||
// to the logger passed on initialization.
|
||||
package netadaptor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
type acceptReq struct {
|
||||
callback chan net.Conn
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
al transport.AuthenticatedListener
|
||||
logger Logger
|
||||
accepts chan acceptReq
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
// Consume the given authListener and wrap it as a *Listener, which implements net.Listener.
|
||||
// The returned net.Listener must be closed.
|
||||
// The wrapped authListener is closed when the returned net.Listener is closed.
|
||||
func New(authListener transport.AuthenticatedListener, l Logger) *Listener {
|
||||
if l == nil {
|
||||
l = logger.NewNullLogger()
|
||||
}
|
||||
a := &Listener{
|
||||
al: authListener,
|
||||
logger: l,
|
||||
accepts: make(chan acceptReq),
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
go a.handleAccept()
|
||||
return a
|
||||
}
|
||||
|
||||
// The returned net.Conn is guaranteed to be *transport.AuthConn, i.e., the type of connection
|
||||
// returned by the wrapped transport.AuthenticatedListener.
|
||||
func (a Listener) Accept() (net.Conn, error) {
|
||||
req := acceptReq{make(chan net.Conn, 1)}
|
||||
a.accepts <- req
|
||||
conn := <-req.callback
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (a Listener) handleAccept() {
|
||||
for {
|
||||
select {
|
||||
case <-a.stop:
|
||||
a.logger.Debug("handleAccept stop accepting")
|
||||
return
|
||||
case req := <-a.accepts:
|
||||
for {
|
||||
a.logger.Debug("accept authListener")
|
||||
authConn, err := a.al.Accept(context.Background())
|
||||
if err != nil {
|
||||
a.logger.WithError(err).Error("accept error")
|
||||
continue
|
||||
}
|
||||
a.logger.WithField("type", fmt.Sprintf("%T", authConn)).
|
||||
Debug("accept complete")
|
||||
req.callback <- authConn
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a Listener) Addr() net.Addr {
|
||||
return a.al.Addr()
|
||||
}
|
||||
|
||||
func (a Listener) Close() error {
|
||||
close(a.stop)
|
||||
return a.al.Close()
|
||||
}
|
106
rpc/rpc_client.go
Normal file
106
rpc/rpc_client.go
Normal file
@ -0,0 +1,106 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zrepl/zrepl/replication"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn"
|
||||
"github.com/zrepl/zrepl/rpc/grpcclientidentity/grpchelper"
|
||||
"github.com/zrepl/zrepl/rpc/versionhandshake"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
// Client implements the active side of a replication setup.
|
||||
// It satisfies the Endpoint, Sender and Receiver interface defined by package replication.
|
||||
type Client struct {
|
||||
dataClient *dataconn.Client
|
||||
controlClient pdu.ReplicationClient // this the grpc client instance, see constructor
|
||||
controlConn *grpc.ClientConn
|
||||
loggers Loggers
|
||||
}
|
||||
|
||||
var _ replication.Endpoint = &Client{}
|
||||
var _ replication.Sender = &Client{}
|
||||
var _ replication.Receiver = &Client{}
|
||||
|
||||
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
|
||||
// config must be validated, NewClient will panic if it is not valid
|
||||
func NewClient(cn transport.Connecter, loggers Loggers) *Client {
|
||||
|
||||
cn = versionhandshake.Connecter(cn, envconst.Duration("ZREPL_RPC_CLIENT_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second))
|
||||
|
||||
muxedConnecter := mux(cn)
|
||||
|
||||
c := &Client{
|
||||
loggers: loggers,
|
||||
}
|
||||
grpcConn := grpchelper.ClientConn(muxedConnecter.control, loggers.Control)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
state := grpcConn.GetState()
|
||||
loggers.General.WithField("grpc_state", state.String()).Debug("grpc state change")
|
||||
grpcConn.WaitForStateChange(context.TODO(), state)
|
||||
}
|
||||
}()
|
||||
c.controlClient = pdu.NewReplicationClient(grpcConn)
|
||||
c.controlConn = grpcConn
|
||||
|
||||
c.dataClient = dataconn.NewClient(muxedConnecter.data, loggers.Data)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
if err := c.controlConn.Close(); err != nil {
|
||||
c.loggers.General.WithError(err).Error("cannot cloe control connection")
|
||||
}
|
||||
// TODO c.dataClient should have Close()
|
||||
}
|
||||
|
||||
// callers must ensure that the returned io.ReadCloser is closed
|
||||
// TODO expose dataClient interface to the outside world
|
||||
func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
// TODO the returned sendStream may return a read error created by the remote side
|
||||
res, streamCopier, err := c.dataClient.ReqSend(ctx, r)
|
||||
if err != nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if streamCopier == nil {
|
||||
return res, nil, nil
|
||||
}
|
||||
|
||||
return res, streamCopier, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
return c.dataClient.ReqRecv(ctx, req, streamCopier)
|
||||
}
|
||||
|
||||
func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
|
||||
return c.controlClient.ListFilesystems(ctx, in)
|
||||
}
|
||||
|
||||
func (c *Client) ListFilesystemVersions(ctx context.Context, in *pdu.ListFilesystemVersionsReq) (*pdu.ListFilesystemVersionsRes, error) {
|
||||
return c.controlClient.ListFilesystemVersions(ctx, in)
|
||||
}
|
||||
|
||||
func (c *Client) DestroySnapshots(ctx context.Context, in *pdu.DestroySnapshotsReq) (*pdu.DestroySnapshotsRes, error) {
|
||||
return c.controlClient.DestroySnapshots(ctx, in)
|
||||
}
|
||||
|
||||
func (c *Client) ReplicationCursor(ctx context.Context, in *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) {
|
||||
return c.controlClient.ReplicationCursor(ctx, in)
|
||||
}
|
||||
|
||||
func (c *Client) ResetConnectBackoff() {
|
||||
c.controlConn.ResetConnectBackoff()
|
||||
}
|
20
rpc/rpc_debug.go
Normal file
20
rpc/rpc_debug.go
Normal file
@ -0,0 +1,20 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var debugEnabled bool = false
|
||||
|
||||
func init() {
|
||||
if os.Getenv("ZREPL_RPC_DEBUG") != "" {
|
||||
debugEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func debug(format string, args ...interface{}) {
|
||||
if debugEnabled {
|
||||
fmt.Fprintf(os.Stderr, "rpc: %s\n", fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
118
rpc/rpc_doc.go
Normal file
118
rpc/rpc_doc.go
Normal file
@ -0,0 +1,118 @@
|
||||
// Package rpc implements zrepl daemon-to-daemon RPC protocol
|
||||
// on top of a transport provided by package transport.
|
||||
// The zrepl documentation refers to the client as the
|
||||
// `active side` and to the server as the `passive side`.
|
||||
//
|
||||
// Design Considerations
|
||||
//
|
||||
// zrepl has non-standard requirements to remote procedure calls (RPC):
|
||||
// whereas the coordination of replication (the planning phase) mostly
|
||||
// consists of regular unary calls, the actual data transfer requires
|
||||
// a high-throughput, low-overhead solution.
|
||||
//
|
||||
// Specifically, the requirements for data transfer is to perform
|
||||
// a copy of an io.Reader over the wire, such that an io.EOF of the original
|
||||
// reader corresponds to an io.EOF on the receiving side.
|
||||
// If any other error occurs while reading from the original io.Reader
|
||||
// on the sender, the receiver should receive the contents of that error
|
||||
// in some form (usually as a trailer message)
|
||||
// A common implementation technique for above data transfer is chunking,
|
||||
// for example in HTTP:
|
||||
// https://tools.ietf.org/html/rfc2616#section-3.6.1
|
||||
//
|
||||
// With regards to regular RPCs, gRPC is a popular implementation based
|
||||
// on protocol buffers and thus code generation.
|
||||
// gRPC also supports bi-directional streaming RPC, and it is possible
|
||||
// to implement chunked transfer through the use of streams.
|
||||
//
|
||||
// For zrepl however, both HTTP and manually implemented chunked transfer
|
||||
// using gRPC were found to have significant CPU overhead at transfer
|
||||
// speeds to be expected even with hobbyist users.
|
||||
//
|
||||
// However, it is nice to piggyback on the code generation provided
|
||||
// by protobuf / gRPC, in particular since the majority of call types
|
||||
// are regular unary RPCs for which the higher overhead of gRPC is acceptable.
|
||||
//
|
||||
// Hence, this package attempts to combine the best of both worlds:
|
||||
//
|
||||
// GRPC for Coordination and Dataconn for Bulk Data Transfer
|
||||
//
|
||||
// This package's Client uses its transport.Connecter to maintain
|
||||
// separate control and data connections to the Server.
|
||||
// The control connection is used by an instance of pdu.ReplicationClient
|
||||
// whose 'regular' unary RPC calls are re-exported.
|
||||
// The data connection is used by an instance of dataconn.Client and
|
||||
// is used for bulk data transfer, namely `Send` and `Receive`.
|
||||
//
|
||||
// The following ASCII diagram gives an overview of how the individual
|
||||
// building blocks are glued together:
|
||||
//
|
||||
// +------------+
|
||||
// | rpc.Client |
|
||||
// +------------+
|
||||
// | |
|
||||
// +--------+ +------------+
|
||||
// | |
|
||||
// +---------v-----------+ +--------v------+
|
||||
// |pdu.ReplicationClient| |dataconn.Client|
|
||||
// +---------------------+ +--------v------+
|
||||
// | label: label: |
|
||||
// | zrepl_control zrepl_data |
|
||||
// +--------+ +------------+
|
||||
// | |
|
||||
// +--v---------v---+
|
||||
// | transportmux |
|
||||
// +-------+--------+
|
||||
// | uses
|
||||
// +-------v--------+
|
||||
// |versionhandshake|
|
||||
// +-------+--------+
|
||||
// | uses
|
||||
// +------v------+
|
||||
// | transport |
|
||||
// +------+------+
|
||||
// |
|
||||
// NETWORK
|
||||
// |
|
||||
// +------+------+
|
||||
// | transport |
|
||||
// +------^------+
|
||||
// | uses
|
||||
// +-------+--------+
|
||||
// |versionhandshake|
|
||||
// +-------^--------+
|
||||
// | uses
|
||||
// +-------+--------+
|
||||
// | transportmux |
|
||||
// +--^--------^----+
|
||||
// | |
|
||||
// +--------+ --------------+ ---
|
||||
// | | |
|
||||
// | label: label: | |
|
||||
// | zrepl_control zrepl_data | |
|
||||
// +-----+----+ +-----------+---+ |
|
||||
// |netadaptor| |dataconn.Server| | rpc.Server
|
||||
// | + | +------+--------+ |
|
||||
// |grpcclient| | |
|
||||
// |identity | | |
|
||||
// +-----+----+ | |
|
||||
// | | |
|
||||
// +---------v-----------+ | |
|
||||
// |pdu.ReplicationServer| | |
|
||||
// +---------+-----------+ | |
|
||||
// | | ---
|
||||
// +----------+ +------------+
|
||||
// | |
|
||||
// +---v--v-----+
|
||||
// | Handler |
|
||||
// +------------+
|
||||
// (usually endpoint.{Sender,Receiver})
|
||||
//
|
||||
//
|
||||
package rpc
|
||||
|
||||
// edit trick for the ASCII art above:
|
||||
// - remove the comments //
|
||||
// - vim: set virtualedit+=all
|
||||
// - vim: set ft=text
|
||||
|
34
rpc/rpc_logging.go
Normal file
34
rpc/rpc_logging.go
Normal file
@ -0,0 +1,34 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
contextKeyLoggers contextKey = iota
|
||||
contextKeyGeneralLogger
|
||||
contextKeyControlLogger
|
||||
contextKeyDataLogger
|
||||
)
|
||||
|
||||
/// All fields must be non-nil
|
||||
type Loggers struct {
|
||||
General Logger
|
||||
Control Logger
|
||||
Data Logger
|
||||
}
|
||||
|
||||
func WithLoggers(ctx context.Context, loggers Loggers) context.Context {
|
||||
ctx = context.WithValue(ctx, contextKeyLoggers, loggers)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func GetLoggersOrPanic(ctx context.Context) Loggers {
|
||||
return ctx.Value(contextKeyLoggers).(Loggers)
|
||||
}
|
57
rpc/rpc_mux.go
Normal file
57
rpc/rpc_mux.go
Normal file
@ -0,0 +1,57 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/rpc/transportmux"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
)
|
||||
|
||||
type demuxedListener struct {
|
||||
control, data transport.AuthenticatedListener
|
||||
}
|
||||
|
||||
const (
|
||||
transportmuxLabelControl string = "zrepl_control"
|
||||
transportmuxLabelData string = "zrepl_data"
|
||||
)
|
||||
|
||||
func demux(serveCtx context.Context, listener transport.AuthenticatedListener) demuxedListener {
|
||||
listeners, err := transportmux.Demux(
|
||||
serveCtx, listener,
|
||||
[]string{transportmuxLabelControl, transportmuxLabelData},
|
||||
envconst.Duration("ZREPL_TRANSPORT_DEMUX_TIMEOUT", 10*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
// transportmux API guarantees that the returned error can only be due
|
||||
// to invalid API usage (i.e. labels too long)
|
||||
panic(err)
|
||||
}
|
||||
return demuxedListener{
|
||||
control: listeners[transportmuxLabelControl],
|
||||
data: listeners[transportmuxLabelData],
|
||||
}
|
||||
}
|
||||
|
||||
type muxedConnecter struct {
|
||||
control, data transport.Connecter
|
||||
}
|
||||
|
||||
func mux(rawConnecter transport.Connecter) muxedConnecter {
|
||||
muxedConnecters, err := transportmux.MuxConnecter(
|
||||
rawConnecter,
|
||||
[]string{transportmuxLabelControl, transportmuxLabelData},
|
||||
envconst.Duration("ZREPL_TRANSPORT_MUX_TIMEOUT", 10*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
// transportmux API guarantees that the returned error can only be due
|
||||
// to invalid API usage (i.e. labels too long)
|
||||
panic(err)
|
||||
}
|
||||
return muxedConnecter{
|
||||
control: muxedConnecters[transportmuxLabelControl],
|
||||
data: muxedConnecters[transportmuxLabelData],
|
||||
}
|
||||
}
|
119
rpc/rpc_server.go
Normal file
119
rpc/rpc_server.go
Normal file
@ -0,0 +1,119 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zrepl/zrepl/endpoint"
|
||||
"github.com/zrepl/zrepl/replication/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn"
|
||||
"github.com/zrepl/zrepl/rpc/grpcclientidentity"
|
||||
"github.com/zrepl/zrepl/rpc/netadaptor"
|
||||
"github.com/zrepl/zrepl/rpc/versionhandshake"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
pdu.ReplicationServer
|
||||
dataconn.Handler
|
||||
}
|
||||
|
||||
type serveFunc func(ctx context.Context, demuxedListener transport.AuthenticatedListener, errOut chan<- error)
|
||||
|
||||
// Server abstracts the accept and request routing infrastructure for the
|
||||
// passive side of a replication setup.
|
||||
type Server struct {
|
||||
logger Logger
|
||||
handler Handler
|
||||
controlServer *grpc.Server
|
||||
controlServerServe serveFunc
|
||||
dataServer *dataconn.Server
|
||||
dataServerServe serveFunc
|
||||
}
|
||||
|
||||
type serverContextKey int
|
||||
|
||||
type HandlerContextInterceptor func(ctx context.Context) context.Context
|
||||
|
||||
// config must be valid (use its Validate function).
|
||||
func NewServer(handler Handler, loggers Loggers, ctxInterceptor HandlerContextInterceptor) *Server {
|
||||
|
||||
// setup control server
|
||||
tcs := grpcclientidentity.NewTransportCredentials(loggers.Control) // TODO different subsystem for log
|
||||
unary, stream := grpcclientidentity.NewInterceptors(loggers.Control, endpoint.ClientIdentityKey)
|
||||
controlServer := grpc.NewServer(grpc.Creds(tcs), grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream))
|
||||
pdu.RegisterReplicationServer(controlServer, handler)
|
||||
controlServerServe := func(ctx context.Context, controlListener transport.AuthenticatedListener, errOut chan<- error) {
|
||||
// give time for graceful stop until deadline expires, then hard stop
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
go time.AfterFunc(dl.Sub(dl), controlServer.Stop)
|
||||
}
|
||||
loggers.Control.Debug("shutting down control server")
|
||||
controlServer.GracefulStop()
|
||||
}()
|
||||
|
||||
errOut <- controlServer.Serve(netadaptor.New(controlListener, loggers.Control))
|
||||
}
|
||||
|
||||
// setup data server
|
||||
dataServerClientIdentitySetter := func(ctx context.Context, wire *transport.AuthConn) (context.Context, *transport.AuthConn) {
|
||||
ci := wire.ClientIdentity()
|
||||
ctx = context.WithValue(ctx, endpoint.ClientIdentityKey, ci)
|
||||
if ctxInterceptor != nil {
|
||||
ctx = ctxInterceptor(ctx) // SHADOWING
|
||||
}
|
||||
return ctx, wire
|
||||
}
|
||||
dataServer := dataconn.NewServer(dataServerClientIdentitySetter, loggers.Data, handler)
|
||||
dataServerServe := func(ctx context.Context, dataListener transport.AuthenticatedListener, errOut chan<- error) {
|
||||
dataServer.Serve(ctx, dataListener)
|
||||
errOut <- nil // TODO bad design of dataServer?
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
logger: loggers.General,
|
||||
handler: handler,
|
||||
controlServer: controlServer,
|
||||
controlServerServe: controlServerServe,
|
||||
dataServer: dataServer,
|
||||
dataServerServe: dataServerServe,
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// The context is used for cancellation only.
|
||||
// Serve never returns an error, it logs them to the Server's logger.
|
||||
func (s *Server) Serve(ctx context.Context, l transport.AuthenticatedListener) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
l = versionhandshake.Listener(l, envconst.Duration("ZREPL_RPC_SERVER_VERSIONHANDSHAKE_TIMEOUT", 10*time.Second))
|
||||
|
||||
// it is important that demux's context is cancelled,
|
||||
// it has background goroutines attached
|
||||
demuxListener := demux(ctx, l)
|
||||
|
||||
serveErrors := make(chan error, 2)
|
||||
go s.controlServerServe(ctx, demuxListener.control, serveErrors)
|
||||
go s.dataServerServe(ctx, demuxListener.data, serveErrors)
|
||||
select {
|
||||
case serveErr := <-serveErrors:
|
||||
s.logger.WithError(serveErr).Error("serve error")
|
||||
s.logger.Debug("wait for other server to shut down")
|
||||
cancel()
|
||||
secondServeErr := <-serveErrors
|
||||
s.logger.WithError(secondServeErr).Error("serve error")
|
||||
case <-ctx.Done():
|
||||
s.logger.Debug("context cancelled, wait for control and data servers")
|
||||
cancel()
|
||||
for i := 0; i < 2; i++ {
|
||||
<-serveErrors
|
||||
}
|
||||
s.logger.Debug("control and data server shut down, returning from Serve")
|
||||
}
|
||||
}
|
205
rpc/transportmux/transportmux.go
Normal file
205
rpc/transportmux/transportmux.go
Normal file
@ -0,0 +1,205 @@
|
||||
// Package transportmux wraps a transport.{Connecter,AuthenticatedListener}
|
||||
// to distinguish different connection types based on a label
|
||||
// sent from client to server on connection establishment.
|
||||
//
|
||||
// Labels are plain text and fixed length.
|
||||
package transportmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
"fmt"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
contextKeyLog contextKey = 1 + iota
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
func WithLogger(ctx context.Context, log Logger) context.Context {
|
||||
return context.WithValue(ctx, contextKeyLog, log)
|
||||
}
|
||||
|
||||
func getLog(ctx context.Context) Logger {
|
||||
if l, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
||||
return l
|
||||
}
|
||||
return logger.NewNullLogger()
|
||||
}
|
||||
|
||||
type acceptRes struct {
|
||||
conn *transport.AuthConn
|
||||
err error
|
||||
}
|
||||
|
||||
type demuxListener struct {
|
||||
conns chan acceptRes
|
||||
}
|
||||
|
||||
func (l *demuxListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
res := <-l.conns
|
||||
return res.conn, res.err
|
||||
}
|
||||
|
||||
type demuxAddr struct {}
|
||||
|
||||
func (demuxAddr) Network() string { return "demux" }
|
||||
func (demuxAddr) String() string { return "demux" }
|
||||
|
||||
func (l *demuxListener) Addr() net.Addr {
|
||||
return demuxAddr{}
|
||||
}
|
||||
|
||||
func (l *demuxListener) Close() error { return nil } // TODO
|
||||
|
||||
// Exact length of a label in bytes (0-byte padded if it is shorter).
|
||||
// This is a protocol constant, changing it breaks the wire protocol.
|
||||
const LabelLen = 64
|
||||
|
||||
func padLabel(out []byte, label string) (error) {
|
||||
if len(label) > LabelLen {
|
||||
return fmt.Errorf("label %q exceeds max length (is %d, max %d)", label, len(label), LabelLen)
|
||||
}
|
||||
if len(out) != LabelLen {
|
||||
panic(fmt.Sprintf("implementation error: %d", out))
|
||||
}
|
||||
labelBytes := []byte(label)
|
||||
copy(out[:], labelBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func Demux(ctx context.Context, rawListener transport.AuthenticatedListener, labels []string, timeout time.Duration) (map[string]transport.AuthenticatedListener, error) {
|
||||
|
||||
padded := make(map[[64]byte]*demuxListener, len(labels))
|
||||
ret := make(map[string]transport.AuthenticatedListener, len(labels))
|
||||
for _, label := range labels {
|
||||
var labelPadded [LabelLen]byte
|
||||
err := padLabel(labelPadded[:], label)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := padded[labelPadded]; ok {
|
||||
return nil, fmt.Errorf("duplicate label %q", label)
|
||||
}
|
||||
dl := &demuxListener{make(chan acceptRes)}
|
||||
padded[labelPadded] = dl
|
||||
ret[label] = dl
|
||||
}
|
||||
|
||||
// invariant: padded contains same-length, non-duplicate labels
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
getLog(ctx).Debug("context cancelled, closing listener")
|
||||
if err := rawListener.Close(); err != nil {
|
||||
getLog(ctx).WithError(err).Error("error closing listener")
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
rawConn, err := rawListener.Accept(ctx)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
getLog(ctx).WithError(err).Error("accept error")
|
||||
continue
|
||||
}
|
||||
closeConn := func() {
|
||||
if err := rawConn.Close(); err != nil {
|
||||
getLog(ctx).WithError(err).Error("cannot close conn")
|
||||
}
|
||||
}
|
||||
|
||||
if err := rawConn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
getLog(ctx).WithError(err).Error("SetDeadline failed")
|
||||
closeConn()
|
||||
continue
|
||||
}
|
||||
|
||||
var labelBuf [LabelLen]byte
|
||||
if _, err := io.ReadFull(rawConn, labelBuf[:]); err != nil {
|
||||
getLog(ctx).WithError(err).Error("error reading label")
|
||||
closeConn()
|
||||
continue
|
||||
}
|
||||
|
||||
demuxListener, ok := padded[labelBuf]
|
||||
if !ok {
|
||||
getLog(ctx).WithError(err).
|
||||
WithField("client_label", fmt.Sprintf("%q", labelBuf)).
|
||||
Error("unknown client label")
|
||||
closeConn()
|
||||
continue
|
||||
}
|
||||
|
||||
rawConn.SetDeadline(time.Time{})
|
||||
// blocking is intentional
|
||||
demuxListener.conns <- acceptRes{conn: rawConn, err: nil}
|
||||
}
|
||||
}()
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type labeledConnecter struct {
|
||||
label []byte
|
||||
transport.Connecter
|
||||
}
|
||||
|
||||
func (c labeledConnecter) Connect(ctx context.Context) (transport.Wire, error) {
|
||||
conn, err := c.Connecter.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
closeConn := func(why error) {
|
||||
getLog(ctx).WithField("reason", why.Error()).Debug("closing connection")
|
||||
if err := conn.Close(); err != nil {
|
||||
getLog(ctx).WithError(err).Error("error closing connection after label write error")
|
||||
}
|
||||
}
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
if err := conn.SetDeadline(dl); err != nil {
|
||||
closeConn(err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
n, err := conn.Write(c.label)
|
||||
if err != nil {
|
||||
closeConn(err)
|
||||
return nil, err
|
||||
}
|
||||
if n != len(c.label) {
|
||||
closeConn(fmt.Errorf("short label write"))
|
||||
return nil, io.ErrShortWrite
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func MuxConnecter(rawConnecter transport.Connecter, labels []string, timeout time.Duration) (map[string]transport.Connecter, error) {
|
||||
ret := make(map[string]transport.Connecter, len(labels))
|
||||
for _, label := range labels {
|
||||
var paddedLabel [LabelLen]byte
|
||||
if err := padLabel(paddedLabel[:], label); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lc := &labeledConnecter{paddedLabel[:], rawConnecter}
|
||||
if _, ok := ret[label]; ok {
|
||||
return nil, fmt.Errorf("duplicate label %q", label)
|
||||
}
|
||||
ret[label] = lc
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
181
rpc/versionhandshake/versionhandshake.go
Normal file
181
rpc/versionhandshake/versionhandshake.go
Normal file
@ -0,0 +1,181 @@
|
||||
// Package versionhandshake wraps a transport.{Connecter,AuthenticatedListener}
|
||||
// to add an exchange of protocol version information on connection establishment.
|
||||
//
|
||||
// The protocol version information (banner) is plain text, thus making it
|
||||
// easy to diagnose issues with standard tools.
|
||||
package versionhandshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type HandshakeMessage struct {
|
||||
ProtocolVersion int
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
// A HandshakeError describes what went wrong during the handshake.
|
||||
// It implements net.Error and is always temporary.
|
||||
type HandshakeError struct {
|
||||
msg string
|
||||
// If not nil, the underlying IO error that caused the handshake to fail.
|
||||
IOError error
|
||||
}
|
||||
|
||||
var _ net.Error = &HandshakeError{}
|
||||
|
||||
func (e HandshakeError) Error() string { return e.msg }
|
||||
|
||||
// Always true to enable usage in a net.Listener.
|
||||
func (e HandshakeError) Temporary() bool { return true }
|
||||
|
||||
// If the underlying IOError was net.Error.Timeout(), Timeout() returns that value.
|
||||
// Otherwise false.
|
||||
func (e HandshakeError) Timeout() bool {
|
||||
if neterr, ok := e.IOError.(net.Error); ok {
|
||||
return neterr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hsErr(format string, args... interface{}) *HandshakeError {
|
||||
return &HandshakeError{msg: fmt.Sprintf(format, args...)}
|
||||
}
|
||||
|
||||
func hsIOErr(err error, format string, args... interface{}) *HandshakeError {
|
||||
return &HandshakeError{IOError: err, msg: fmt.Sprintf(format, args...)}
|
||||
}
|
||||
|
||||
// MaxProtocolVersion is the maximum allowed protocol version.
|
||||
// This is a protocol constant, changing it may break the wire format.
|
||||
const MaxProtocolVersion = 9999
|
||||
|
||||
// Only returns *HandshakeError as error.
|
||||
func (m *HandshakeMessage) Encode() ([]byte, error) {
|
||||
if m.ProtocolVersion <= 0 || m.ProtocolVersion > MaxProtocolVersion {
|
||||
return nil, hsErr(fmt.Sprintf("protocol version must be in [1, %d]", MaxProtocolVersion))
|
||||
}
|
||||
if len(m.Extensions) >= MaxProtocolVersion {
|
||||
return nil, hsErr(fmt.Sprintf("protocol only supports [0, %d] extensions", MaxProtocolVersion))
|
||||
}
|
||||
// EXTENSIONS is a count of subsequent \n separated lines that contain protocol extensions
|
||||
var extensions strings.Builder
|
||||
for i, ext := range m.Extensions {
|
||||
if strings.ContainsAny(ext, "\n") {
|
||||
return nil, hsErr("Extension #%d contains forbidden newline character", i)
|
||||
}
|
||||
if !utf8.ValidString(ext) {
|
||||
return nil, hsErr("Extension #%d is not valid UTF-8", i)
|
||||
}
|
||||
extensions.WriteString(ext)
|
||||
extensions.WriteString("\n")
|
||||
}
|
||||
withoutLen := fmt.Sprintf("ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%04d\n%s",
|
||||
m.ProtocolVersion, len(m.Extensions), extensions.String())
|
||||
withLen := fmt.Sprintf("%010d %s", len(withoutLen), withoutLen)
|
||||
return []byte(withLen), nil
|
||||
}
|
||||
|
||||
func (m *HandshakeMessage) DecodeReader(r io.Reader, maxLen int) error {
|
||||
var lenAndSpace [11]byte
|
||||
if _, err := io.ReadFull(r, lenAndSpace[:]); err != nil {
|
||||
return hsIOErr(err, "error reading protocol banner length: %s", err)
|
||||
}
|
||||
if !utf8.Valid(lenAndSpace[:]) {
|
||||
return hsErr("invalid start of handshake message: not valid UTF-8")
|
||||
}
|
||||
var followLen int
|
||||
n, err := fmt.Sscanf(string(lenAndSpace[:]), "%010d ", &followLen)
|
||||
if n != 1 || err != nil {
|
||||
return hsErr("could not parse handshake message length")
|
||||
}
|
||||
if followLen > maxLen {
|
||||
return hsErr("handshake message length exceeds max length (%d vs %d)",
|
||||
followLen, maxLen)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err = io.Copy(&buf, io.LimitReader(r, int64(followLen)))
|
||||
if err != nil {
|
||||
return hsIOErr(err, "error reading protocol banner body: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
protoVersion, extensionCount int
|
||||
)
|
||||
n, err = fmt.Fscanf(&buf, "ZREPL_ZFS_REPLICATION PROTOVERSION=%04d EXTENSIONS=%4d\n",
|
||||
&protoVersion, &extensionCount)
|
||||
if n != 2 || err != nil {
|
||||
return hsErr("could not parse handshake message: %s", err)
|
||||
}
|
||||
if protoVersion < 1 {
|
||||
return hsErr("invalid protocol version %q", protoVersion)
|
||||
}
|
||||
m.ProtocolVersion = protoVersion
|
||||
|
||||
if extensionCount < 0 {
|
||||
return hsErr("invalid extension count %q", extensionCount)
|
||||
}
|
||||
if extensionCount == 0 {
|
||||
if buf.Len() != 0 {
|
||||
return hsErr("unexpected data trailing after header")
|
||||
}
|
||||
m.Extensions = nil
|
||||
return nil
|
||||
}
|
||||
s := buf.String()
|
||||
if strings.Count(s, "\n") != extensionCount {
|
||||
return hsErr("inconsistent extension count: found %d, header says %d", len(m.Extensions), extensionCount)
|
||||
}
|
||||
exts := strings.Split(s, "\n")
|
||||
if exts[len(exts)-1] != "" {
|
||||
return hsErr("unexpected data trailing after last extension newline")
|
||||
}
|
||||
m.Extensions = exts[0:len(exts)-1]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoHandshakeCurrentVersion(conn net.Conn, deadline time.Time) error {
|
||||
// current protocol version is hardcoded here
|
||||
return DoHandshakeVersion(conn, deadline, 1)
|
||||
}
|
||||
|
||||
const HandshakeMessageMaxLen = 16 * 4096
|
||||
|
||||
func DoHandshakeVersion(conn net.Conn, deadline time.Time, version int) error {
|
||||
ours := HandshakeMessage{
|
||||
ProtocolVersion: version,
|
||||
Extensions: nil,
|
||||
}
|
||||
hsb, err := ours.Encode()
|
||||
if err != nil {
|
||||
return hsErr("could not encode protocol banner: %s", err)
|
||||
}
|
||||
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
conn.SetDeadline(deadline)
|
||||
_, err = io.Copy(conn, bytes.NewBuffer(hsb))
|
||||
if err != nil {
|
||||
return hsErr("could not send protocol banner: %s", err)
|
||||
}
|
||||
|
||||
theirs := HandshakeMessage{}
|
||||
if err := theirs.DecodeReader(conn, HandshakeMessageMaxLen); err != nil {
|
||||
return hsErr("could not decode protocol banner: %s", err)
|
||||
}
|
||||
|
||||
if theirs.ProtocolVersion != ours.ProtocolVersion {
|
||||
return hsErr("protocol versions do not match: ours is %d, theirs is %d",
|
||||
ours.ProtocolVersion, theirs.ProtocolVersion)
|
||||
}
|
||||
// ignore extensions, we don't use them
|
||||
|
||||
return nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package transport
|
||||
package versionhandshake
|
||||
|
||||
import (
|
||||
"bytes"
|
66
rpc/versionhandshake/versionhandshake_transport_wrappers.go
Normal file
66
rpc/versionhandshake/versionhandshake_transport_wrappers.go
Normal file
@ -0,0 +1,66 @@
|
||||
package versionhandshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type HandshakeConnecter struct {
|
||||
connecter transport.Connecter
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func (c HandshakeConnecter) Connect(ctx context.Context) (transport.Wire, error) {
|
||||
conn, err := c.connecter.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
dl = time.Now().Add(c.timeout)
|
||||
}
|
||||
if err := DoHandshakeCurrentVersion(conn, dl); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func Connecter(connecter transport.Connecter, timeout time.Duration) HandshakeConnecter {
|
||||
return HandshakeConnecter{
|
||||
connecter: connecter,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// wrapper type that performs a a protocol version handshake before returning the connection
|
||||
type HandshakeListener struct {
|
||||
l transport.AuthenticatedListener
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func (l HandshakeListener) Addr() (net.Addr) { return l.l.Addr() }
|
||||
|
||||
func (l HandshakeListener) Close() error { return l.l.Close() }
|
||||
|
||||
func (l HandshakeListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
conn, err := l.l.Accept(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
dl = time.Now().Add(l.timeout) // shadowing
|
||||
}
|
||||
if err := DoHandshakeCurrentVersion(conn, dl); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func Listener(l transport.AuthenticatedListener, timeout time.Duration) transport.AuthenticatedListener {
|
||||
return HandshakeListener{l, timeout}
|
||||
}
|
@ -4,8 +4,11 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -22,12 +25,13 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
|
||||
}
|
||||
|
||||
type ClientAuthListener struct {
|
||||
l net.Listener
|
||||
l *net.TCPListener
|
||||
c *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewClientAuthListener(
|
||||
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
|
||||
l *net.TCPListener, ca *x509.CertPool, serverCert tls.Certificate,
|
||||
handshakeTimeout time.Duration) *ClientAuthListener {
|
||||
|
||||
if ca == nil {
|
||||
@ -37,29 +41,35 @@ func NewClientAuthListener(
|
||||
panic(serverCert)
|
||||
}
|
||||
|
||||
tlsConf := tls.Config{
|
||||
tlsConf := &tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
ClientCAs: ca,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
PreferServerCipherSuites: true,
|
||||
KeyLogWriter: keylogFromEnv(),
|
||||
}
|
||||
l = tls.NewListener(l, &tlsConf)
|
||||
return &ClientAuthListener{
|
||||
l,
|
||||
tlsConf,
|
||||
handshakeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
|
||||
c, err = l.l.Accept()
|
||||
// Accept() accepts a connection from the *net.TCPListener passed to the constructor
|
||||
// and sets up the TLS connection, including handshake and peer CommmonName validation
|
||||
// within the specified handshakeTimeout.
|
||||
//
|
||||
// It returns both the raw TCP connection (tcpConn) and the TLS connection (tlsConn) on top of it.
|
||||
// Access to the raw tcpConn might be necessary if CloseWrite semantics are desired:
|
||||
// tlsConn.CloseWrite does NOT call tcpConn.CloseWrite, hence we provide access to tcpConn to
|
||||
// allow the caller to do this by themselves.
|
||||
func (l *ClientAuthListener) Accept() (tcpConn *net.TCPConn, tlsConn *tls.Conn, clientCN string, err error) {
|
||||
tcpConn, err = l.l.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
tlsConn, ok := c.(*tls.Conn)
|
||||
if !ok {
|
||||
return c, "", err
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
tlsConn = tls.Server(tcpConn, l.c)
|
||||
var (
|
||||
cn string
|
||||
peerCerts []*x509.Certificate
|
||||
@ -70,6 +80,7 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
|
||||
if err = tlsConn.Handshake(); err != nil {
|
||||
goto CloseAndErr
|
||||
}
|
||||
tlsConn.SetDeadline(time.Time{})
|
||||
|
||||
peerCerts = tlsConn.ConnectionState().PeerCertificates
|
||||
if len(peerCerts) < 1 {
|
||||
@ -77,10 +88,11 @@ func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
|
||||
goto CloseAndErr
|
||||
}
|
||||
cn = peerCerts[0].Subject.CommonName
|
||||
return c, cn, nil
|
||||
return tcpConn, tlsConn, cn, nil
|
||||
CloseAndErr:
|
||||
c.Close()
|
||||
return nil, "", err
|
||||
// unlike CloseWrite, Close on *tls.Conn actually closes the underlying connection
|
||||
tlsConn.Close() // TODO log error
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
func (l *ClientAuthListener) Addr() net.Addr {
|
||||
@ -105,7 +117,21 @@ func ClientAuthClient(serverName string, rootCA *x509.CertPool, clientCert tls.C
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
RootCAs: rootCA,
|
||||
ServerName: serverName,
|
||||
KeyLogWriter: keylogFromEnv(),
|
||||
}
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func keylogFromEnv() io.Writer {
|
||||
var keyLog io.Writer = nil
|
||||
if outfile := os.Getenv("ZREPL_KEYLOG_FILE"); outfile != "" {
|
||||
fmt.Fprintf(os.Stderr, "writing to key log %s\n", outfile)
|
||||
var err error
|
||||
keyLog, err = os.OpenFile(outfile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return keyLog
|
||||
}
|
||||
|
58
transport/fromconfig/transport_fromconfig.go
Normal file
58
transport/fromconfig/transport_fromconfig.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Package fromconfig instantiates transports based on zrepl config structures
|
||||
// (see package config).
|
||||
package fromconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/transport/local"
|
||||
"github.com/zrepl/zrepl/transport/ssh"
|
||||
"github.com/zrepl/zrepl/transport/tcp"
|
||||
"github.com/zrepl/zrepl/transport/tls"
|
||||
)
|
||||
|
||||
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory,error) {
|
||||
|
||||
var (
|
||||
l transport.AuthenticatedListenerFactory
|
||||
err error
|
||||
)
|
||||
switch v := in.Ret.(type) {
|
||||
case *config.TCPServe:
|
||||
l, err = tcp.TCPListenerFactoryFromConfig(g, v)
|
||||
case *config.TLSServe:
|
||||
l, err = tls.TLSListenerFactoryFromConfig(g, v)
|
||||
case *config.StdinserverServer:
|
||||
l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v)
|
||||
case *config.LocalServe:
|
||||
l, err = local.LocalListenerFactoryFromConfig(g, v)
|
||||
default:
|
||||
return nil, errors.Errorf("internal error: unknown serve type %T", v)
|
||||
}
|
||||
|
||||
return l, err
|
||||
}
|
||||
|
||||
|
||||
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) {
|
||||
var (
|
||||
connecter transport.Connecter
|
||||
err error
|
||||
)
|
||||
switch v := in.Ret.(type) {
|
||||
case *config.SSHStdinserverConnect:
|
||||
connecter, err = ssh.SSHStdinserverConnecterFromConfig(v)
|
||||
case *config.TCPConnect:
|
||||
connecter, err = tcp.TCPConnecterFromConfig(v)
|
||||
case *config.TLSConnect:
|
||||
connecter, err = tls.TLSConnecterFromConfig(v)
|
||||
case *config.LocalConnect:
|
||||
connecter, err = local.LocalConnecterFromConfig(v)
|
||||
default:
|
||||
panic(fmt.Sprintf("implementation error: unknown connecter type %T", v))
|
||||
}
|
||||
|
||||
return connecter, err
|
||||
}
|
@ -1,11 +1,10 @@
|
||||
package connecter
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/transport/serve"
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type LocalConnecter struct {
|
||||
@ -23,8 +22,8 @@ func LocalConnecterFromConfig(in *config.LocalConnect) (*LocalConnecter, error)
|
||||
return &LocalConnecter{listenerName: in.ListenerName, clientIdentity: in.ClientIdentity}, nil
|
||||
}
|
||||
|
||||
func (c *LocalConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) {
|
||||
l := serve.GetLocalListener(c.listenerName)
|
||||
func (c *LocalConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
|
||||
l := GetLocalListener(c.listenerName)
|
||||
return l.Connect(dialCtx, c.clientIdentity)
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
package serve
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"github.com/zrepl/zrepl/util/socketpair"
|
||||
"net"
|
||||
"sync"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
var localListeners struct {
|
||||
@ -39,7 +40,7 @@ type connectRequest struct {
|
||||
}
|
||||
|
||||
type connectResult struct {
|
||||
conn net.Conn
|
||||
conn transport.Wire
|
||||
err error
|
||||
}
|
||||
|
||||
@ -54,7 +55,7 @@ func newLocalListener() *LocalListener {
|
||||
}
|
||||
|
||||
// Connect to the LocalListener from a client with identity clientIdentity
|
||||
func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn net.Conn, err error) {
|
||||
func (l *LocalListener) Connect(dialCtx context.Context, clientIdentity string) (conn transport.Wire, err error) {
|
||||
|
||||
// place request
|
||||
req := connectRequest{
|
||||
@ -89,21 +90,14 @@ func (a localAddr) String() string { return a.S }
|
||||
|
||||
func (l *LocalListener) Addr() (net.Addr) { return localAddr{"<listening>"} }
|
||||
|
||||
type localConn struct {
|
||||
net.Conn
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (l localConn) ClientIdentity() string { return l.clientIdentity }
|
||||
|
||||
func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
func (l *LocalListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
respondToRequest := func(req connectRequest, res connectResult) (err error) {
|
||||
getLogger(ctx).
|
||||
transport.GetLogger(ctx).
|
||||
WithField("res.conn", res.conn).WithField("res.err", res.err).
|
||||
Debug("responding to client request")
|
||||
defer func() {
|
||||
errv := recover()
|
||||
getLogger(ctx).WithField("recover_err", errv).
|
||||
transport.GetLogger(ctx).WithField("recover_err", errv).
|
||||
Debug("panic on send to client callback, likely a legitimate client-side timeout")
|
||||
}()
|
||||
select {
|
||||
@ -116,7 +110,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
getLogger(ctx).Debug("waiting for local client connect requests")
|
||||
transport.GetLogger(ctx).Debug("waiting for local client connect requests")
|
||||
var req connectRequest
|
||||
select {
|
||||
case req = <-l.connects:
|
||||
@ -124,7 +118,7 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
getLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request")
|
||||
transport.GetLogger(ctx).WithField("client_identity", req.clientIdentity).Debug("got connect request")
|
||||
if req.clientIdentity == "" {
|
||||
res := connectResult{nil, fmt.Errorf("client identity must not be empty")}
|
||||
if err := respondToRequest(req, res); err != nil {
|
||||
@ -133,31 +127,31 @@ func (l *LocalListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
return nil, fmt.Errorf("client connected with empty client identity")
|
||||
}
|
||||
|
||||
getLogger(ctx).Debug("creating socketpair")
|
||||
transport.GetLogger(ctx).Debug("creating socketpair")
|
||||
left, right, err := socketpair.SocketPair()
|
||||
if err != nil {
|
||||
res := connectResult{nil, fmt.Errorf("server error: %s", err)}
|
||||
if respErr := respondToRequest(req, res); respErr != nil {
|
||||
// returning the socketpair error properly is more important than the error sent to the client
|
||||
getLogger(ctx).WithError(respErr).Error("error responding to client")
|
||||
transport.GetLogger(ctx).WithError(respErr).Error("error responding to client")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getLogger(ctx).Debug("responding with left side of socketpair")
|
||||
transport.GetLogger(ctx).Debug("responding with left side of socketpair")
|
||||
res := connectResult{left, nil}
|
||||
if err := respondToRequest(req, res); err != nil {
|
||||
getLogger(ctx).WithError(err).Error("error responding to client")
|
||||
transport.GetLogger(ctx).WithError(err).Error("error responding to client")
|
||||
if err := left.Close(); err != nil {
|
||||
getLogger(ctx).WithError(err).Error("cannot close left side of socketpair")
|
||||
transport.GetLogger(ctx).WithError(err).Error("cannot close left side of socketpair")
|
||||
}
|
||||
if err := right.Close(); err != nil {
|
||||
getLogger(ctx).WithError(err).Error("cannot close right side of socketpair")
|
||||
transport.GetLogger(ctx).WithError(err).Error("cannot close right side of socketpair")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return localConn{right, req.clientIdentity}, nil
|
||||
return transport.NewAuthConn(right, req.clientIdentity), nil
|
||||
}
|
||||
|
||||
func (l *LocalListener) Close() error {
|
||||
@ -169,19 +163,13 @@ func (l *LocalListener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type LocalListenerFactory struct {
|
||||
listenerName string
|
||||
}
|
||||
|
||||
func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (f *LocalListenerFactory, err error) {
|
||||
func LocalListenerFactoryFromConfig(g *config.Global, in *config.LocalServe) (transport.AuthenticatedListenerFactory,error) {
|
||||
if in.ListenerName == "" {
|
||||
return nil, fmt.Errorf("ListenerName must not be empty")
|
||||
}
|
||||
return &LocalListenerFactory{listenerName: in.ListenerName}, nil
|
||||
listenerName := in.ListenerName
|
||||
lf := func() (transport.AuthenticatedListener,error) {
|
||||
return GetLocalListener(listenerName), nil
|
||||
}
|
||||
|
||||
|
||||
func (lf *LocalListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
return GetLocalListener(lf.listenerName), nil
|
||||
return lf, nil
|
||||
}
|
||||
|
@ -1,13 +1,12 @@
|
||||
package connecter
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-netssh"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -22,8 +21,6 @@ type SSHStdinserverConnecter struct {
|
||||
dialTimeout time.Duration
|
||||
}
|
||||
|
||||
var _ streamrpc.Connecter = &SSHStdinserverConnecter{}
|
||||
|
||||
func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSHStdinserverConnecter, err error) {
|
||||
|
||||
c = &SSHStdinserverConnecter{
|
||||
@ -39,15 +36,7 @@ func SSHStdinserverConnecterFromConfig(in *config.SSHStdinserverConnect) (c *SSH
|
||||
|
||||
}
|
||||
|
||||
type netsshConnToConn struct{ *netssh.SSHConn }
|
||||
|
||||
var _ net.Conn = netsshConnToConn{}
|
||||
|
||||
func (netsshConnToConn) SetDeadline(dl time.Time) error { return nil }
|
||||
func (netsshConnToConn) SetReadDeadline(dl time.Time) error { return nil }
|
||||
func (netsshConnToConn) SetWriteDeadline(dl time.Time) error { return nil }
|
||||
|
||||
func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, error) {
|
||||
func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
|
||||
|
||||
var endpoint netssh.Endpoint
|
||||
if err := copier.Copy(&endpoint, c); err != nil {
|
||||
@ -62,5 +51,5 @@ func (c *SSHStdinserverConnecter) Connect(dialCtx context.Context) (net.Conn, er
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return netsshConnToConn{nconn}, nil
|
||||
return nconn, nil
|
||||
}
|
@ -1,50 +1,38 @@
|
||||
package serve
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"github.com/problame/go-netssh"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/daemon/nethelpers"
|
||||
"io"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"fmt"
|
||||
"net"
|
||||
"path"
|
||||
"time"
|
||||
"context"
|
||||
"github.com/pkg/errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type StdinserverListenerFactory struct {
|
||||
ClientIdentities []string
|
||||
Sockdir string
|
||||
}
|
||||
|
||||
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
|
||||
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (transport.AuthenticatedListenerFactory,error) {
|
||||
|
||||
for _, ci := range in.ClientIdentities {
|
||||
if err := ValidateClientIdentity(ci); err != nil {
|
||||
if err := transport.ValidateClientIdentity(ci); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
|
||||
}
|
||||
}
|
||||
|
||||
f = &multiStdinserverListenerFactory{
|
||||
ClientIdentities: in.ClientIdentities,
|
||||
Sockdir: g.Serve.StdinServer.SockDir,
|
||||
clientIdentities := in.ClientIdentities
|
||||
sockdir := g.Serve.StdinServer.SockDir
|
||||
|
||||
lf := func() (transport.AuthenticatedListener,error) {
|
||||
return multiStdinserverListenerFromClientIdentities(sockdir, clientIdentities)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type multiStdinserverListenerFactory struct {
|
||||
ClientIdentities []string
|
||||
Sockdir string
|
||||
}
|
||||
|
||||
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
type multiStdinserverAcceptRes struct {
|
||||
conn AuthenticatedConn
|
||||
conn *transport.AuthConn
|
||||
err error
|
||||
}
|
||||
|
||||
@ -78,7 +66,7 @@ func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string)
|
||||
return &MultiStdinserverListener{listeners: listeners}, nil
|
||||
}
|
||||
|
||||
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){
|
||||
func (m *MultiStdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error){
|
||||
|
||||
if m.accepts == nil {
|
||||
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
|
||||
@ -97,8 +85,22 @@ func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedCon
|
||||
|
||||
}
|
||||
|
||||
func (m *MultiStdinserverListener) Addr() (net.Addr) {
|
||||
return netsshAddr{}
|
||||
type multiListenerAddr struct {
|
||||
clients []string
|
||||
}
|
||||
|
||||
func (multiListenerAddr) Network() string { return "netssh" }
|
||||
|
||||
func (l multiListenerAddr) String() string {
|
||||
return fmt.Sprintf("netssh:clients=%v", l.clients)
|
||||
}
|
||||
|
||||
func (m *MultiStdinserverListener) Addr() net.Addr {
|
||||
cis := make([]string, len(m.listeners))
|
||||
for i := range cis {
|
||||
cis[i] = m.listeners[i].clientIdentity
|
||||
}
|
||||
return multiListenerAddr{cis}
|
||||
}
|
||||
|
||||
func (m *MultiStdinserverListener) Close() error {
|
||||
@ -118,41 +120,28 @@ type stdinserverListener struct {
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Addr() net.Addr {
|
||||
return netsshAddr{}
|
||||
type listenerAddr struct {
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
func (listenerAddr) Network() string { return "netssh" }
|
||||
|
||||
func (a listenerAddr) String() string {
|
||||
return fmt.Sprintf("netssh:client=%q", a.clientIdentity)
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Addr() net.Addr {
|
||||
return listenerAddr{l.clientIdentity}
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil
|
||||
return transport.NewAuthConn(c, l.clientIdentity), nil
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Close() (err error) {
|
||||
return l.l.Close()
|
||||
}
|
||||
|
||||
type netsshAddr struct{}
|
||||
|
||||
func (netsshAddr) Network() string { return "netssh" }
|
||||
func (netsshAddr) String() string { return "???" }
|
||||
|
||||
type netsshConnToNetConnAdatper struct {
|
||||
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
|
||||
|
||||
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
|
||||
|
||||
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
|
||||
|
||||
// FIXME log warning once!
|
||||
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (netsshConnToNetConnAdatper) SetWriteDeadline(t time.Time) error { return nil }
|
@ -1,9 +1,11 @@
|
||||
package connecter
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"net"
|
||||
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type TCPConnecter struct {
|
||||
@ -19,6 +21,10 @@ func TCPConnecterFromConfig(in *config.TCPConnect) (*TCPConnecter, error) {
|
||||
return &TCPConnecter{in.Address, dialer}, nil
|
||||
}
|
||||
|
||||
func (c *TCPConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) {
|
||||
return c.dialer.DialContext(dialCtx, "tcp", c.Address)
|
||||
func (c *TCPConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
|
||||
conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn.(*net.TCPConn), nil
|
||||
}
|
@ -1,17 +1,13 @@
|
||||
package serve
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"net"
|
||||
"github.com/pkg/errors"
|
||||
"context"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type TCPListenerFactory struct {
|
||||
address *net.TCPAddr
|
||||
clientMap *ipMap
|
||||
}
|
||||
|
||||
type ipMapEntry struct {
|
||||
ip net.IP
|
||||
ident string
|
||||
@ -28,7 +24,7 @@ func ipMapFromConfig(clients map[string]string) (*ipMap, error) {
|
||||
if clientIP == nil {
|
||||
return nil, errors.Errorf("cannot parse client IP %q", clientIPString)
|
||||
}
|
||||
if err := ValidateClientIdentity(clientIdent); err != nil {
|
||||
if err := transport.ValidateClientIdentity(clientIdent); err != nil {
|
||||
return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString)
|
||||
}
|
||||
entries = append(entries, ipMapEntry{clientIP, clientIdent})
|
||||
@ -45,7 +41,7 @@ func (m *ipMap) Get(ip net.IP) (string, error) {
|
||||
return "", errors.Errorf("no identity mapping for client IP %s", ip)
|
||||
}
|
||||
|
||||
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) {
|
||||
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (transport.AuthenticatedListenerFactory, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", in.Listen)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse listen address")
|
||||
@ -54,19 +50,14 @@ func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPLi
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse client IP map")
|
||||
}
|
||||
lf := &TCPListenerFactory{
|
||||
address: addr,
|
||||
clientMap: clientMap,
|
||||
}
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
l, err := net.ListenTCP("tcp", f.address)
|
||||
lf := func() (transport.AuthenticatedListener, error) {
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPAuthListener{l, f.clientMap}, nil
|
||||
return &TCPAuthListener{l, clientMap}, nil
|
||||
}
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
type TCPAuthListener struct {
|
||||
@ -74,18 +65,18 @@ type TCPAuthListener struct {
|
||||
clientMap *ipMap
|
||||
}
|
||||
|
||||
func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
nc, err := f.TCPListener.Accept()
|
||||
func (f *TCPAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
nc, err := f.TCPListener.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientIP := nc.RemoteAddr().(*net.TCPAddr).IP
|
||||
clientIdent, err := f.clientMap.Get(clientIP)
|
||||
if err != nil {
|
||||
getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
|
||||
transport.GetLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
|
||||
nc.Close()
|
||||
return nil, err
|
||||
}
|
||||
return authConn{nc, clientIdent}, nil
|
||||
return transport.NewAuthConn(nc, clientIdent), nil
|
||||
}
|
||||
|
@ -1,12 +1,14 @@
|
||||
package connecter
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/tlsconf"
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
)
|
||||
|
||||
type TLSConnecter struct {
|
||||
@ -38,10 +40,12 @@ func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) {
|
||||
return &TLSConnecter{in.Address, dialer, tlsConfig}, nil
|
||||
}
|
||||
|
||||
func (c *TLSConnecter) Connect(dialCtx context.Context) (conn net.Conn, err error) {
|
||||
conn, err = c.dialer.DialContext(dialCtx, "tcp", c.Address)
|
||||
func (c *TLSConnecter) Connect(dialCtx context.Context) (transport.Wire, error) {
|
||||
conn, err := c.dialer.DialContext(dialCtx, "tcp", c.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tls.Client(conn, c.tlsConfig), nil
|
||||
tcpConn := conn.(*net.TCPConn)
|
||||
tlsConn := tls.Client(conn, c.tlsConfig)
|
||||
return newWireAdaptor(tlsConn, tcpConn), nil
|
||||
}
|
89
transport/tls/serve_tls.go
Normal file
89
transport/tls/serve_tls.go
Normal file
@ -0,0 +1,89 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/tlsconf"
|
||||
"net"
|
||||
"time"
|
||||
"context"
|
||||
)
|
||||
|
||||
type TLSListenerFactory struct {
|
||||
address string
|
||||
clientCA *x509.CertPool
|
||||
serverCert tls.Certificate
|
||||
handshakeTimeout time.Duration
|
||||
clientCNs map[string]struct{}
|
||||
}
|
||||
|
||||
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory,error) {
|
||||
|
||||
address := in.Listen
|
||||
handshakeTimeout := in.HandshakeTimeout
|
||||
|
||||
if in.Ca == "" || in.Cert == "" || in.Key == "" {
|
||||
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
|
||||
}
|
||||
|
||||
clientCA, err := tlsconf.ParseCAFile(in.Ca)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse ca file")
|
||||
}
|
||||
|
||||
serverCert, err := tls.LoadX509KeyPair(in.Cert, in.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse cer/key pair")
|
||||
}
|
||||
|
||||
clientCNs := make(map[string]struct{}, len(in.ClientCNs))
|
||||
for i, cn := range in.ClientCNs {
|
||||
if err := transport.ValidateClientIdentity(cn); err != nil {
|
||||
return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn)
|
||||
}
|
||||
// dupes are ok fr now
|
||||
clientCNs[cn] = struct{}{}
|
||||
}
|
||||
|
||||
lf := func() (transport.AuthenticatedListener, error) {
|
||||
l, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcpL := l.(*net.TCPListener)
|
||||
tl := tlsconf.NewClientAuthListener(tcpL, clientCA, serverCert, handshakeTimeout)
|
||||
return &tlsAuthListener{tl, clientCNs}, nil
|
||||
}
|
||||
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
type tlsAuthListener struct {
|
||||
*tlsconf.ClientAuthListener
|
||||
clientCNs map[string]struct{}
|
||||
}
|
||||
|
||||
func (l tlsAuthListener) Accept(ctx context.Context) (*transport.AuthConn, error) {
|
||||
tcpConn, tlsConn, cn, err := l.ClientAuthListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := l.clientCNs[cn]; !ok {
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
defer tlsConn.SetDeadline(time.Time{})
|
||||
tlsConn.SetDeadline(dl)
|
||||
}
|
||||
if err := tlsConn.Close(); err != nil {
|
||||
transport.GetLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name")
|
||||
}
|
||||
return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, tlsConn.RemoteAddr())
|
||||
}
|
||||
adaptor := newWireAdaptor(tlsConn, tcpConn)
|
||||
return transport.NewAuthConn(adaptor, cn), nil
|
||||
}
|
||||
|
||||
|
47
transport/tls/tls_wire_adaptor.go
Normal file
47
transport/tls/tls_wire_adaptor.go
Normal file
@ -0,0 +1,47 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
// adapts a tls.Conn and its underlying net.TCPConn into a valid transport.Wire
|
||||
type transportWireAdaptor struct {
|
||||
*tls.Conn
|
||||
tcpConn *net.TCPConn
|
||||
}
|
||||
|
||||
func newWireAdaptor(tlsConn *tls.Conn, tcpConn *net.TCPConn) transportWireAdaptor {
|
||||
return transportWireAdaptor{tlsConn, tcpConn}
|
||||
}
|
||||
|
||||
// CloseWrite implements transport.Wire.CloseWrite which is different from *tls.Conn.CloseWrite:
|
||||
// the former requires that the other side observes io.EOF, but *tls.Conn.CloseWrite does not
|
||||
// close the underlying connection so no io.EOF would be observed.
|
||||
func (w transportWireAdaptor) CloseWrite() error {
|
||||
if err := w.Conn.CloseWrite(); err != nil {
|
||||
// TODO log error
|
||||
fmt.Fprintf(os.Stderr, "transport/tls.CloseWrite() error: %s\n", err)
|
||||
}
|
||||
return w.tcpConn.CloseWrite()
|
||||
}
|
||||
|
||||
// Close implements transport.Wire.Close which is different from a *tls.Conn.Close:
|
||||
// At the time of writing (Go 1.11), closing tls.Conn closes the TCP connection immediately,
|
||||
// which results in io.ErrUnexpectedEOF on the other side.
|
||||
// We assume that w.Conn has a deadline set for the close, so the CloseWrite will time out if it blocks,
|
||||
// falling through to the actual Close()
|
||||
func (w transportWireAdaptor) Close() error {
|
||||
// var buf [1<<15]byte
|
||||
// w.Conn.Write(buf[:])
|
||||
// CloseWrite will send a TLS alert record down the line which
|
||||
// in the Go implementation acts like a flush...?
|
||||
// if err := w.Conn.CloseWrite(); err != nil {
|
||||
// // TODO log error
|
||||
// fmt.Fprintf(os.Stderr, "transport/tls.Close() close write error: %s\n", err)
|
||||
// }
|
||||
// time.Sleep(1 * time.Second)
|
||||
return w.Conn.Close()
|
||||
}
|
84
transport/transport.go
Normal file
84
transport/transport.go
Normal file
@ -0,0 +1,84 @@
|
||||
// Package transport defines a common interface for
|
||||
// network connections that have an associated client identity.
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type AuthConn struct {
|
||||
Wire
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
var _ timeoutconn.SyscallConner = AuthConn{}
|
||||
|
||||
func (a AuthConn) SyscallConn() (rawConn syscall.RawConn, err error) {
|
||||
scc, ok := a.Wire.(timeoutconn.SyscallConner)
|
||||
if !ok {
|
||||
return nil, timeoutconn.SyscallConnNotSupported
|
||||
}
|
||||
return scc.SyscallConn()
|
||||
}
|
||||
|
||||
func NewAuthConn(conn Wire, clientIdentity string) *AuthConn {
|
||||
return &AuthConn{conn, clientIdentity}
|
||||
}
|
||||
|
||||
func (c *AuthConn) ClientIdentity() string {
|
||||
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c.clientIdentity
|
||||
}
|
||||
|
||||
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
|
||||
type AuthenticatedListener interface {
|
||||
Addr() net.Addr
|
||||
Accept(ctx context.Context) (*AuthConn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type AuthenticatedListenerFactory func() (AuthenticatedListener, error)
|
||||
|
||||
type Wire = timeoutconn.Wire
|
||||
|
||||
type Connecter interface {
|
||||
Connect(ctx context.Context) (Wire, error)
|
||||
}
|
||||
|
||||
// A client identity must be a single component in a ZFS filesystem path
|
||||
func ValidateClientIdentity(in string) (err error) {
|
||||
path, err := zfs.NewDatasetPath(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if path.Length() != 1 {
|
||||
return errors.New("client identity must be a single path comonent (not empty, no '/')")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type contextKey int
|
||||
|
||||
const contextKeyLog contextKey = 0
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
func WithLogger(ctx context.Context, log Logger) context.Context {
|
||||
return context.WithValue(ctx, contextKeyLog, log)
|
||||
}
|
||||
|
||||
func GetLogger(ctx context.Context) Logger {
|
||||
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
||||
return log
|
||||
}
|
||||
return logger.NewNullLogger()
|
||||
}
|
71
util/bytecounter/bytecounter_streamcopier.go
Normal file
71
util/bytecounter/bytecounter_streamcopier.go
Normal file
@ -0,0 +1,71 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
// StreamCopier wraps a zfs.StreamCopier, reimplemening
|
||||
// its interface and counting the bytes written to during copying.
|
||||
type StreamCopier interface {
|
||||
zfs.StreamCopier
|
||||
Count() int64
|
||||
}
|
||||
|
||||
// NewStreamCopier wraps sc into a StreamCopier.
|
||||
// If sc is io.Reader, it is guaranteed that the returned StreamCopier
|
||||
// implements that interface, too.
|
||||
func NewStreamCopier(sc zfs.StreamCopier) StreamCopier {
|
||||
bsc := &streamCopier{sc, 0}
|
||||
if scr, ok := sc.(io.Reader); ok {
|
||||
return streamCopierAndReader{bsc, scr}
|
||||
} else {
|
||||
return bsc
|
||||
}
|
||||
}
|
||||
|
||||
type streamCopier struct {
|
||||
sc zfs.StreamCopier
|
||||
count int64
|
||||
}
|
||||
|
||||
// proxy writer used by streamCopier
|
||||
type streamCopierWriter struct {
|
||||
parent *streamCopier
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (w streamCopierWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = w.w.Write(p)
|
||||
atomic.AddInt64(&w.parent.count, int64(n))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamCopier) Count() int64 {
|
||||
return atomic.LoadInt64(&s.count)
|
||||
}
|
||||
|
||||
var _ zfs.StreamCopier = &streamCopier{}
|
||||
|
||||
func (s streamCopier) Close() error {
|
||||
return s.sc.Close()
|
||||
}
|
||||
|
||||
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
|
||||
ww := streamCopierWriter{s, w}
|
||||
return s.sc.WriteStreamTo(ww)
|
||||
}
|
||||
|
||||
// a streamCopier whose underlying sc is an io.Reader
|
||||
type streamCopierAndReader struct {
|
||||
*streamCopier
|
||||
asReader io.Reader
|
||||
}
|
||||
|
||||
func (scr streamCopierAndReader) Read(p []byte) (int, error) {
|
||||
n, err := scr.asReader.Read(p)
|
||||
atomic.AddInt64(&scr.streamCopier.count, int64(n))
|
||||
return n, err
|
||||
}
|
38
util/bytecounter/bytecounter_streamcopier_test.go
Normal file
38
util/bytecounter/bytecounter_streamcopier_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type mockStreamCopierAndReader struct {
|
||||
zfs.StreamCopier // to satisfy interface
|
||||
reads int
|
||||
}
|
||||
|
||||
func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) {
|
||||
r.reads++
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
var _ io.Reader = &mockStreamCopierAndReader{}
|
||||
|
||||
func TestNewStreamCopierReexportsReader(t *testing.T) {
|
||||
mock := &mockStreamCopierAndReader{}
|
||||
x := NewStreamCopier(mock)
|
||||
|
||||
r, ok := x.(io.Reader)
|
||||
if !ok {
|
||||
t.Fatalf("%T does not implement io.Reader, hence reader cannout have been wrapped", x)
|
||||
}
|
||||
|
||||
var buf [23]byte
|
||||
n, err := r.Read(buf[:])
|
||||
assert.True(t, mock.reads == 1)
|
||||
assert.True(t, n == len(buf))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, x.Count() == 23)
|
||||
}
|
14
util/devnoop/devnoop.go
Normal file
14
util/devnoop/devnoop.go
Normal file
@ -0,0 +1,14 @@
|
||||
// package devnoop provides an io.ReadWriteCloser that never errors
|
||||
// and always reports reads / writes to / from buffers as complete.
|
||||
// The buffers themselves are never touched.
|
||||
package devnoop
|
||||
|
||||
type Dev struct{}
|
||||
|
||||
func Get() Dev {
|
||||
return Dev{}
|
||||
}
|
||||
|
||||
func (Dev) Write(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (Dev) Read(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (Dev) Close() error { return nil }
|
@ -2,6 +2,7 @@ package envconst
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -23,3 +24,19 @@ func Duration(varname string, def time.Duration) time.Duration {
|
||||
cache.Store(varname, d)
|
||||
return d
|
||||
}
|
||||
|
||||
func Int64(varname string, def int64) int64 {
|
||||
if v, ok := cache.Load(varname); ok {
|
||||
return v.(int64)
|
||||
}
|
||||
e := os.Getenv(varname)
|
||||
if e == "" {
|
||||
return def
|
||||
}
|
||||
d, err := strconv.ParseInt(e, 10, 64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cache.Store(varname, d)
|
||||
return d
|
||||
}
|
||||
|
@ -4,15 +4,18 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface.
|
||||
type IOCommand struct {
|
||||
Cmd *exec.Cmd
|
||||
kill context.CancelFunc
|
||||
Stdin io.WriteCloser
|
||||
Stdout io.ReadCloser
|
||||
StderrBuf *bytes.Buffer
|
||||
@ -52,6 +55,7 @@ func NewIOCommand(ctx context.Context, command string, args []string, stderrBufS
|
||||
|
||||
c = &IOCommand{}
|
||||
|
||||
ctx, c.kill = context.WithCancel(ctx)
|
||||
c.Cmd = exec.CommandContext(ctx, command, args...)
|
||||
|
||||
if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil {
|
||||
@ -81,14 +85,24 @@ func (c *IOCommand) Start() (err error) {
|
||||
func (c *IOCommand) Read(buf []byte) (n int, err error) {
|
||||
n, err = c.Stdout.Read(buf)
|
||||
if err == io.EOF {
|
||||
if waitErr := c.doWait(); waitErr != nil {
|
||||
if waitErr := c.doWait(context.Background()); waitErr != nil {
|
||||
err = waitErr
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *IOCommand) doWait() (err error) {
|
||||
func (c *IOCommand) doWait(ctx context.Context) (err error) {
|
||||
go func() {
|
||||
dl, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
time.Sleep(dl.Sub(time.Now()))
|
||||
c.kill()
|
||||
c.Stdout.Close()
|
||||
c.Stdin.Close()
|
||||
}()
|
||||
waitErr := c.Cmd.Wait()
|
||||
var wasUs bool = false
|
||||
var waitStatus syscall.WaitStatus
|
||||
@ -133,10 +147,9 @@ func (c *IOCommand) Close() (err error) {
|
||||
if c.Cmd.ProcessState == nil {
|
||||
// racy...
|
||||
err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.doWait()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), envconst.Duration("IOCOMMAND_TIMEOUT", 10*time.Second))
|
||||
defer cancel()
|
||||
return c.doWait(ctx)
|
||||
} else {
|
||||
return c.ExitResult.Error
|
||||
}
|
||||
|
@ -1,42 +1,32 @@
|
||||
package socketpair
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
type fileConn struct {
|
||||
net.Conn // net.FileConn
|
||||
f *os.File
|
||||
}
|
||||
|
||||
func (c fileConn) Close() error {
|
||||
if err := c.Conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func SocketPair() (a, b net.Conn, err error) {
|
||||
func SocketPair() (a, b *net.UnixConn, err error) {
|
||||
// don't use net.Pipe, as it doesn't implement things like lingering, which our code relies on
|
||||
sockpair, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
toConn := func(fd int) (net.Conn, error) {
|
||||
toConn := func(fd int) (*net.UnixConn, error) {
|
||||
f := os.NewFile(uintptr(fd), "fileconn")
|
||||
if f == nil {
|
||||
panic(fd)
|
||||
}
|
||||
c, err := net.FileConn(f)
|
||||
f.Close() // net.FileConn uses dup under the hood
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return fileConn{Conn: c, f: f}, nil
|
||||
// strictly, the following type assertion is an implementation detail
|
||||
// however, will be caught by test TestSocketPairWorks
|
||||
fileConnIsUnixConn := c.(*net.UnixConn)
|
||||
return fileConnIsUnixConn, nil
|
||||
}
|
||||
if a, err = toConn(sockpair[0]); err != nil { // shadowing
|
||||
return nil, nil, err
|
||||
|
18
util/socketpair/socketpair_test.go
Normal file
18
util/socketpair/socketpair_test.go
Normal file
@ -0,0 +1,18 @@
|
||||
package socketpair
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// This is test is mostly to verify that the assumption about
|
||||
// net.FileConn returning *net.UnixConn for AF_UNIX FDs works.
|
||||
func TestSocketPairWorks(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
a, b, err := SocketPair()
|
||||
assert.NoError(t, err)
|
||||
a.Close()
|
||||
b.Close()
|
||||
})
|
||||
}
|
@ -274,7 +274,7 @@ func ZFSCreatePlaceholderFilesystem(p *DatasetPath) (err error) {
|
||||
}
|
||||
|
||||
if err = cmd.Wait(); err != nil {
|
||||
err = ZFSError{
|
||||
err = &ZFSError{
|
||||
Stderr: stderr.Bytes(),
|
||||
WaitErr: err,
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user