Add --skip-cert-check flag to zrepl configcheck to prevent checking cert files

It may be desirable to check that a config is valid without checking for
the existence of certificate files (e.g. when validating a config inside
a sandbox without access to the cert files).

This will be very useful for NixOS so that we can check the config file
at nix-build time (e.g. potentially without proper permissions to read cert
files for a TLS connection).

fixes https://github.com/zrepl/zrepl/issues/467
closes https://github.com/zrepl/zrepl/pull/587
This commit is contained in:
Cole Helbling 2022-03-29 19:39:10 -07:00 committed by Christian Schwarz
parent e4112d888c
commit 1df0f8912a
13 changed files with 46 additions and 41 deletions

View File

@ -21,6 +21,7 @@ import (
var configcheckArgs struct { var configcheckArgs struct {
format string format string
what string what string
skipCertCheck bool
} }
var ConfigcheckCmd = &cli.Subcommand{ var ConfigcheckCmd = &cli.Subcommand{
@ -29,6 +30,7 @@ var ConfigcheckCmd = &cli.Subcommand{
SetupFlags: func(f *pflag.FlagSet) { SetupFlags: func(f *pflag.FlagSet) {
f.StringVar(&configcheckArgs.format, "format", "", "dump parsed config object [pretty|yaml|json]") f.StringVar(&configcheckArgs.format, "format", "", "dump parsed config object [pretty|yaml|json]")
f.StringVar(&configcheckArgs.what, "what", "all", "what to print [all|config|jobs|logging]") f.StringVar(&configcheckArgs.what, "what", "all", "what to print [all|config|jobs|logging]")
f.BoolVar(&configcheckArgs.skipCertCheck, "skip-cert-check", false, "skip checking cert files")
}, },
Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error { Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
formatMap := map[string]func(interface{}){ formatMap := map[string]func(interface{}){
@ -56,8 +58,16 @@ var ConfigcheckCmd = &cli.Subcommand{
} }
var hadErr bool var hadErr bool
parseFlags := config.ParseFlagsNone
if configcheckArgs.skipCertCheck {
parseFlags |= config.ParseFlagsNoCertCheck
}
// further: try to build jobs // further: try to build jobs
confJobs, err := job.JobsFromConfig(subcommand.Config()) confJobs, err := job.JobsFromConfig(subcommand.Config(), parseFlags)
if err != nil { if err != nil {
err := errors.Wrap(err, "cannot build jobs from config") err := errors.Wrap(err, "cannot build jobs from config")
if configcheckArgs.what == "jobs" { if configcheckArgs.what == "jobs" {

View File

@ -129,7 +129,7 @@ func doMigrateReplicationCursor(ctx context.Context, sc *cli.Subcommand, args []
} }
cfg := sc.Config() cfg := sc.Config()
jobs, err := job.JobsFromConfig(cfg) jobs, err := job.JobsFromConfig(cfg, config.ParseFlagsNone)
if err != nil { if err != nil {
fmt.Printf("cannot parse config:\n%s\n\n", err) fmt.Printf("cannot parse config:\n%s\n\n", err)
fmt.Printf("NOTE: this migration was released together with a change in job name requirements.\n") fmt.Printf("NOTE: this migration was released together with a change in job name requirements.\n")

View File

@ -17,6 +17,13 @@ import (
zfsprop "github.com/zrepl/zrepl/zfs/property" zfsprop "github.com/zrepl/zrepl/zfs/property"
) )
type ParseFlags uint
const (
ParseFlagsNone ParseFlags = 0
ParseFlagsNoCertCheck ParseFlags = 1 << iota
)
type Config struct { type Config struct {
Jobs []JobEnum `yaml:"jobs"` Jobs []JobEnum `yaml:"jobs"`
Global *Global `yaml:"global,optional,fromdefaults"` Global *Global `yaml:"global,optional,fromdefaults"`

View File

@ -51,7 +51,7 @@ func Run(ctx context.Context, conf *config.Config) error {
} }
outlets.Add(newPrometheusLogOutlet(), logger.Debug) outlets.Add(newPrometheusLogOutlet(), logger.Debug)
confJobs, err := job.JobsFromConfig(conf) confJobs, err := job.JobsFromConfig(conf, config.ParseFlagsNone)
if err != nil { if err != nil {
return errors.Wrap(err, "cannot build jobs from config") return errors.Wrap(err, "cannot build jobs from config")
} }

View File

@ -300,7 +300,7 @@ func replicationDriverConfigFromConfig(in *config.Replication) (c driver.Config,
return c, err return c, err
} }
func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (j *ActiveSide, err error) { func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}, parseFlags config.ParseFlags) (j *ActiveSide, err error) {
j = &ActiveSide{} j = &ActiveSide{}
j.name, err = endpoint.MakeJobID(in.Name) j.name, err = endpoint.MakeJobID(in.Name)
@ -349,7 +349,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (
ConstLabels: prometheus.Labels{"zrepl_job": j.name.String()}, ConstLabels: prometheus.Labels{"zrepl_job": j.name.String()},
}) })
j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect) j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect, parseFlags)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot build client") return nil, errors.Wrap(err, "cannot build client")
} }

View File

@ -11,10 +11,10 @@ import (
"github.com/zrepl/zrepl/util/bandwidthlimit" "github.com/zrepl/zrepl/util/bandwidthlimit"
) )
func JobsFromConfig(c *config.Config) ([]Job, error) { func JobsFromConfig(c *config.Config, parseFlags config.ParseFlags) ([]Job, error) {
js := make([]Job, len(c.Jobs)) js := make([]Job, len(c.Jobs))
for i := range c.Jobs { for i := range c.Jobs {
j, err := buildJob(c.Global, c.Jobs[i]) j, err := buildJob(c.Global, c.Jobs[i], parseFlags)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -42,19 +42,19 @@ func JobsFromConfig(c *config.Config) ([]Job, error) {
return js, nil return js, nil
} }
func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) { func buildJob(c *config.Global, in config.JobEnum, parseFlags config.ParseFlags) (j Job, err error) {
cannotBuildJob := func(e error, name string) (Job, error) { cannotBuildJob := func(e error, name string) (Job, error) {
return nil, errors.Wrapf(e, "cannot build job %q", name) return nil, errors.Wrapf(e, "cannot build job %q", name)
} }
// FIXME prettify this // FIXME prettify this
switch v := in.Ret.(type) { switch v := in.Ret.(type) {
case *config.SinkJob: case *config.SinkJob:
j, err = passiveSideFromConfig(c, &v.PassiveJob, v) j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags)
if err != nil { if err != nil {
return cannotBuildJob(err, v.Name) return cannotBuildJob(err, v.Name)
} }
case *config.SourceJob: case *config.SourceJob:
j, err = passiveSideFromConfig(c, &v.PassiveJob, v) j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags)
if err != nil { if err != nil {
return cannotBuildJob(err, v.Name) return cannotBuildJob(err, v.Name)
} }
@ -64,12 +64,12 @@ func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) {
return cannotBuildJob(err, v.Name) return cannotBuildJob(err, v.Name)
} }
case *config.PushJob: case *config.PushJob:
j, err = activeSide(c, &v.ActiveJob, v) j, err = activeSide(c, &v.ActiveJob, v, parseFlags)
if err != nil { if err != nil {
return cannotBuildJob(err, v.Name) return cannotBuildJob(err, v.Name)
} }
case *config.PullJob: case *config.PullJob:
j, err = activeSide(c, &v.ActiveJob, v) j, err = activeSide(c, &v.ActiveJob, v, parseFlags)
if err != nil { if err != nil {
return cannotBuildJob(err, v.Name) return cannotBuildJob(err, v.Name)
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport/tls"
) )
func TestValidateReceivingSidesDoNotOverlap(t *testing.T) { func TestValidateReceivingSidesDoNotOverlap(t *testing.T) {
@ -96,7 +95,7 @@ jobs:
conf, err := config.ParseConfigBytes([]byte(fill(c.jobName))) conf, err := config.ParseConfigBytes([]byte(fill(c.jobName)))
require.NoError(t, err, "not expecting yaml-config to know about job ids") require.NoError(t, err, "not expecting yaml-config to know about job ids")
require.NotNil(t, conf) require.NotNil(t, conf)
jobs, err := JobsFromConfig(conf) jobs, err := JobsFromConfig(conf, config.ParseFlagsNone)
if c.valid { if c.valid {
assert.NoError(t, err) assert.NoError(t, err)
@ -153,8 +152,7 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) {
t.Logf("file: %s", p) t.Logf("file: %s", p)
t.Log(pretty.Sprint(c)) t.Log(pretty.Sprint(c))
tls.FakeCertificateLoading(t) jobs, err := JobsFromConfig(c, config.ParseFlagsNoCertCheck)
jobs, err := JobsFromConfig(c)
t.Logf("jobs: %#v", jobs) t.Logf("jobs: %#v", jobs)
require.NoError(t, err) require.NoError(t, err)
@ -299,7 +297,7 @@ jobs:
t.Logf("testing config:\n%s", cstr) t.Logf("testing config:\n%s", cstr)
c, err := config.ParseConfigBytes([]byte(cstr)) c, err := config.ParseConfigBytes([]byte(cstr))
require.NoError(t, err) require.NoError(t, err)
jobs, err := JobsFromConfig(c) jobs, err := JobsFromConfig(c, config.ParseFlagsNone)
if ts.expectOk != nil { if ts.expectOk != nil {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, c) require.NotNil(t, c)

View File

@ -91,7 +91,7 @@ func (m *modeSource) SnapperReport() *snapper.Report {
return m.snapper.Report() return m.snapper.Report()
} }
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}) (s *PassiveSide, err error) { func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}, parseFlags config.ParseFlags) (s *PassiveSide, err error) {
s = &PassiveSide{} s = &PassiveSide{}
@ -110,7 +110,7 @@ func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob in
return nil, err // no wrapping necessary return nil, err // no wrapping necessary
} }
if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil { if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve, parseFlags); err != nil {
return nil, errors.Wrap(err, "cannot build listener factory") return nil, errors.Wrap(err, "cannot build listener factory")
} }

View File

@ -67,7 +67,7 @@ func main() {
case "connect": case "connect":
tc, err := getTestCase(args.testCase) tc, err := getTestCase(args.testCase)
noerror(err) noerror(err)
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect) connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect, config.ParseFlagsNone)
noerror(err) noerror(err)
wire, err := connecter.Connect(ctx) wire, err := connecter.Connect(ctx)
noerror(err) noerror(err)
@ -75,7 +75,7 @@ func main() {
case "serve": case "serve":
tc, err := getTestCase(args.testCase) tc, err := getTestCase(args.testCase)
noerror(err) noerror(err)
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve) lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve, config.ParseFlagsNone)
noerror(err) noerror(err)
l, err := lf() l, err := lf()
noerror(err) noerror(err)

View File

@ -15,7 +15,7 @@ import (
"github.com/zrepl/zrepl/transport/tls" "github.com/zrepl/zrepl/transport/tls"
) )
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory, error) { func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) {
var ( var (
l transport.AuthenticatedListenerFactory l transport.AuthenticatedListenerFactory
@ -25,7 +25,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport
case *config.TCPServe: case *config.TCPServe:
l, err = tcp.TCPListenerFactoryFromConfig(g, v) l, err = tcp.TCPListenerFactoryFromConfig(g, v)
case *config.TLSServe: case *config.TLSServe:
l, err = tls.TLSListenerFactoryFromConfig(g, v) l, err = tls.TLSListenerFactoryFromConfig(g, v, parseFlags)
case *config.StdinserverServer: case *config.StdinserverServer:
l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v) l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v)
case *config.LocalServe: case *config.LocalServe:
@ -37,7 +37,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport
return l, err return l, err
} }
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) { func ConnecterFromConfig(g *config.Global, in config.ConnectEnum, parseFlags config.ParseFlags) (transport.Connecter, error) {
var ( var (
connecter transport.Connecter connecter transport.Connecter
err error err error
@ -48,7 +48,7 @@ func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Con
case *config.TCPConnect: case *config.TCPConnect:
connecter, err = tcp.TCPConnecterFromConfig(v) connecter, err = tcp.TCPConnecterFromConfig(v)
case *config.TLSConnect: case *config.TLSConnect:
connecter, err = tls.TLSConnecterFromConfig(v) connecter, err = tls.TLSConnecterFromConfig(v, parseFlags)
case *config.LocalConnect: case *config.LocalConnect:
connecter, err = local.LocalConnecterFromConfig(v) connecter, err = local.LocalConnecterFromConfig(v)
default: default:

View File

@ -18,12 +18,12 @@ type TLSConnecter struct {
tlsConfig *tls.Config tlsConfig *tls.Config
} }
func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) { func TLSConnecterFromConfig(in *config.TLSConnect, parseFlags config.ParseFlags) (*TLSConnecter, error) {
dialer := net.Dialer{ dialer := net.Dialer{
Timeout: in.DialTimeout, Timeout: in.DialTimeout,
} }
if fakeCertificateLoading { if parseFlags&config.ParseFlagsNoCertCheck != 0 {
return &TLSConnecter{in.Address, dialer, nil}, nil return &TLSConnecter{in.Address, dialer, nil}, nil
} }

View File

@ -16,7 +16,7 @@ import (
type TLSListenerFactory struct{} type TLSListenerFactory struct{}
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory, error) { func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) {
address := in.Listen address := in.Listen
handshakeTimeout := in.HandshakeTimeout handshakeTimeout := in.HandshakeTimeout
@ -25,7 +25,7 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transp
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
} }
if fakeCertificateLoading { if parseFlags&config.ParseFlagsNoCertCheck != 0 {
return func() (transport.AuthenticatedListener, error) { return nil, nil }, nil return func() (transport.AuthenticatedListener, error) { return nil, nil }, nil
} }

View File

@ -1,10 +0,0 @@
package tls
import "testing"
var fakeCertificateLoading bool
func FakeCertificateLoading(t *testing.T) {
t.Logf("faking certificate loading")
fakeCertificateLoading = true
}