Add --skip-cert-check flag to zrepl configcheck to prevent checking cert files

It may be desirable to check that a config is valid without checking for
the existence of certificate files (e.g. when validating a config inside
a sandbox without access to the cert files).

This will be very useful for NixOS so that we can check the config file
at nix-build time (e.g. potentially without proper permissions to read cert
files for a TLS connection).

fixes https://github.com/zrepl/zrepl/issues/467
closes https://github.com/zrepl/zrepl/pull/587
This commit is contained in:
Cole Helbling 2022-03-29 19:39:10 -07:00 committed by Christian Schwarz
parent e4112d888c
commit 1df0f8912a
13 changed files with 46 additions and 41 deletions

View File

@ -19,8 +19,9 @@ import (
)
var configcheckArgs struct {
format string
what string
format string
what string
skipCertCheck bool
}
var ConfigcheckCmd = &cli.Subcommand{
@ -29,6 +30,7 @@ var ConfigcheckCmd = &cli.Subcommand{
SetupFlags: func(f *pflag.FlagSet) {
f.StringVar(&configcheckArgs.format, "format", "", "dump parsed config object [pretty|yaml|json]")
f.StringVar(&configcheckArgs.what, "what", "all", "what to print [all|config|jobs|logging]")
f.BoolVar(&configcheckArgs.skipCertCheck, "skip-cert-check", false, "skip checking cert files")
},
Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
formatMap := map[string]func(interface{}){
@ -56,8 +58,16 @@ var ConfigcheckCmd = &cli.Subcommand{
}
var hadErr bool
parseFlags := config.ParseFlagsNone
if configcheckArgs.skipCertCheck {
parseFlags |= config.ParseFlagsNoCertCheck
}
// further: try to build jobs
confJobs, err := job.JobsFromConfig(subcommand.Config())
confJobs, err := job.JobsFromConfig(subcommand.Config(), parseFlags)
if err != nil {
err := errors.Wrap(err, "cannot build jobs from config")
if configcheckArgs.what == "jobs" {

View File

@ -129,7 +129,7 @@ func doMigrateReplicationCursor(ctx context.Context, sc *cli.Subcommand, args []
}
cfg := sc.Config()
jobs, err := job.JobsFromConfig(cfg)
jobs, err := job.JobsFromConfig(cfg, config.ParseFlagsNone)
if err != nil {
fmt.Printf("cannot parse config:\n%s\n\n", err)
fmt.Printf("NOTE: this migration was released together with a change in job name requirements.\n")

View File

@ -17,6 +17,13 @@ import (
zfsprop "github.com/zrepl/zrepl/zfs/property"
)
type ParseFlags uint
const (
ParseFlagsNone ParseFlags = 0
ParseFlagsNoCertCheck ParseFlags = 1 << iota
)
type Config struct {
Jobs []JobEnum `yaml:"jobs"`
Global *Global `yaml:"global,optional,fromdefaults"`

View File

@ -51,7 +51,7 @@ func Run(ctx context.Context, conf *config.Config) error {
}
outlets.Add(newPrometheusLogOutlet(), logger.Debug)
confJobs, err := job.JobsFromConfig(conf)
confJobs, err := job.JobsFromConfig(conf, config.ParseFlagsNone)
if err != nil {
return errors.Wrap(err, "cannot build jobs from config")
}

View File

@ -300,7 +300,7 @@ func replicationDriverConfigFromConfig(in *config.Replication) (c driver.Config,
return c, err
}
func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (j *ActiveSide, err error) {
func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}, parseFlags config.ParseFlags) (j *ActiveSide, err error) {
j = &ActiveSide{}
j.name, err = endpoint.MakeJobID(in.Name)
@ -349,7 +349,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (
ConstLabels: prometheus.Labels{"zrepl_job": j.name.String()},
})
j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect)
j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect, parseFlags)
if err != nil {
return nil, errors.Wrap(err, "cannot build client")
}

View File

@ -11,10 +11,10 @@ import (
"github.com/zrepl/zrepl/util/bandwidthlimit"
)
func JobsFromConfig(c *config.Config) ([]Job, error) {
func JobsFromConfig(c *config.Config, parseFlags config.ParseFlags) ([]Job, error) {
js := make([]Job, len(c.Jobs))
for i := range c.Jobs {
j, err := buildJob(c.Global, c.Jobs[i])
j, err := buildJob(c.Global, c.Jobs[i], parseFlags)
if err != nil {
return nil, err
}
@ -42,19 +42,19 @@ func JobsFromConfig(c *config.Config) ([]Job, error) {
return js, nil
}
func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) {
func buildJob(c *config.Global, in config.JobEnum, parseFlags config.ParseFlags) (j Job, err error) {
cannotBuildJob := func(e error, name string) (Job, error) {
return nil, errors.Wrapf(e, "cannot build job %q", name)
}
// FIXME prettify this
switch v := in.Ret.(type) {
case *config.SinkJob:
j, err = passiveSideFromConfig(c, &v.PassiveJob, v)
j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags)
if err != nil {
return cannotBuildJob(err, v.Name)
}
case *config.SourceJob:
j, err = passiveSideFromConfig(c, &v.PassiveJob, v)
j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags)
if err != nil {
return cannotBuildJob(err, v.Name)
}
@ -64,12 +64,12 @@ func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) {
return cannotBuildJob(err, v.Name)
}
case *config.PushJob:
j, err = activeSide(c, &v.ActiveJob, v)
j, err = activeSide(c, &v.ActiveJob, v, parseFlags)
if err != nil {
return cannotBuildJob(err, v.Name)
}
case *config.PullJob:
j, err = activeSide(c, &v.ActiveJob, v)
j, err = activeSide(c, &v.ActiveJob, v, parseFlags)
if err != nil {
return cannotBuildJob(err, v.Name)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/transport/tls"
)
func TestValidateReceivingSidesDoNotOverlap(t *testing.T) {
@ -96,7 +95,7 @@ jobs:
conf, err := config.ParseConfigBytes([]byte(fill(c.jobName)))
require.NoError(t, err, "not expecting yaml-config to know about job ids")
require.NotNil(t, conf)
jobs, err := JobsFromConfig(conf)
jobs, err := JobsFromConfig(conf, config.ParseFlagsNone)
if c.valid {
assert.NoError(t, err)
@ -153,8 +152,7 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) {
t.Logf("file: %s", p)
t.Log(pretty.Sprint(c))
tls.FakeCertificateLoading(t)
jobs, err := JobsFromConfig(c)
jobs, err := JobsFromConfig(c, config.ParseFlagsNoCertCheck)
t.Logf("jobs: %#v", jobs)
require.NoError(t, err)
@ -299,7 +297,7 @@ jobs:
t.Logf("testing config:\n%s", cstr)
c, err := config.ParseConfigBytes([]byte(cstr))
require.NoError(t, err)
jobs, err := JobsFromConfig(c)
jobs, err := JobsFromConfig(c, config.ParseFlagsNone)
if ts.expectOk != nil {
require.NoError(t, err)
require.NotNil(t, c)

View File

@ -91,7 +91,7 @@ func (m *modeSource) SnapperReport() *snapper.Report {
return m.snapper.Report()
}
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}) (s *PassiveSide, err error) {
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}, parseFlags config.ParseFlags) (s *PassiveSide, err error) {
s = &PassiveSide{}
@ -110,7 +110,7 @@ func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob in
return nil, err // no wrapping necessary
}
if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil {
if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve, parseFlags); err != nil {
return nil, errors.Wrap(err, "cannot build listener factory")
}

View File

@ -67,7 +67,7 @@ func main() {
case "connect":
tc, err := getTestCase(args.testCase)
noerror(err)
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect)
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect, config.ParseFlagsNone)
noerror(err)
wire, err := connecter.Connect(ctx)
noerror(err)
@ -75,7 +75,7 @@ func main() {
case "serve":
tc, err := getTestCase(args.testCase)
noerror(err)
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve)
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve, config.ParseFlagsNone)
noerror(err)
l, err := lf()
noerror(err)

View File

@ -15,7 +15,7 @@ import (
"github.com/zrepl/zrepl/transport/tls"
)
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory, error) {
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) {
var (
l transport.AuthenticatedListenerFactory
@ -25,7 +25,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport
case *config.TCPServe:
l, err = tcp.TCPListenerFactoryFromConfig(g, v)
case *config.TLSServe:
l, err = tls.TLSListenerFactoryFromConfig(g, v)
l, err = tls.TLSListenerFactoryFromConfig(g, v, parseFlags)
case *config.StdinserverServer:
l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v)
case *config.LocalServe:
@ -37,7 +37,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport
return l, err
}
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) {
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum, parseFlags config.ParseFlags) (transport.Connecter, error) {
var (
connecter transport.Connecter
err error
@ -48,7 +48,7 @@ func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Con
case *config.TCPConnect:
connecter, err = tcp.TCPConnecterFromConfig(v)
case *config.TLSConnect:
connecter, err = tls.TLSConnecterFromConfig(v)
connecter, err = tls.TLSConnecterFromConfig(v, parseFlags)
case *config.LocalConnect:
connecter, err = local.LocalConnecterFromConfig(v)
default:

View File

@ -18,12 +18,12 @@ type TLSConnecter struct {
tlsConfig *tls.Config
}
func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) {
func TLSConnecterFromConfig(in *config.TLSConnect, parseFlags config.ParseFlags) (*TLSConnecter, error) {
dialer := net.Dialer{
Timeout: in.DialTimeout,
}
if fakeCertificateLoading {
if parseFlags&config.ParseFlagsNoCertCheck != 0 {
return &TLSConnecter{in.Address, dialer, nil}, nil
}

View File

@ -16,7 +16,7 @@ import (
type TLSListenerFactory struct{}
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory, error) {
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) {
address := in.Listen
handshakeTimeout := in.HandshakeTimeout
@ -25,7 +25,7 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transp
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
}
if fakeCertificateLoading {
if parseFlags&config.ParseFlagsNoCertCheck != 0 {
return func() (transport.AuthenticatedListener, error) { return nil, nil }, nil
}

View File

@ -1,10 +0,0 @@
package tls
import "testing"
var fakeCertificateLoading bool
func FakeCertificateLoading(t *testing.T) {
t.Logf("faking certificate loading")
fakeCertificateLoading = true
}