diff --git a/client/configcheck.go b/client/configcheck.go index 4da99c3..15b9b4a 100644 --- a/client/configcheck.go +++ b/client/configcheck.go @@ -19,8 +19,9 @@ import ( ) var configcheckArgs struct { - format string - what string + format string + what string + skipCertCheck bool } var ConfigcheckCmd = &cli.Subcommand{ @@ -29,6 +30,7 @@ var ConfigcheckCmd = &cli.Subcommand{ SetupFlags: func(f *pflag.FlagSet) { f.StringVar(&configcheckArgs.format, "format", "", "dump parsed config object [pretty|yaml|json]") f.StringVar(&configcheckArgs.what, "what", "all", "what to print [all|config|jobs|logging]") + f.BoolVar(&configcheckArgs.skipCertCheck, "skip-cert-check", false, "skip checking cert files") }, Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error { formatMap := map[string]func(interface{}){ @@ -56,8 +58,16 @@ var ConfigcheckCmd = &cli.Subcommand{ } var hadErr bool + + parseFlags := config.ParseFlagsNone + + if configcheckArgs.skipCertCheck { + parseFlags |= config.ParseFlagsNoCertCheck + } + // further: try to build jobs - confJobs, err := job.JobsFromConfig(subcommand.Config()) + confJobs, err := job.JobsFromConfig(subcommand.Config(), parseFlags) + if err != nil { err := errors.Wrap(err, "cannot build jobs from config") if configcheckArgs.what == "jobs" { diff --git a/client/migrate.go b/client/migrate.go index b73b47c..eddb787 100644 --- a/client/migrate.go +++ b/client/migrate.go @@ -129,7 +129,7 @@ func doMigrateReplicationCursor(ctx context.Context, sc *cli.Subcommand, args [] } cfg := sc.Config() - jobs, err := job.JobsFromConfig(cfg) + jobs, err := job.JobsFromConfig(cfg, config.ParseFlagsNone) if err != nil { fmt.Printf("cannot parse config:\n%s\n\n", err) fmt.Printf("NOTE: this migration was released together with a change in job name requirements.\n") diff --git a/config/config.go b/config/config.go index 17263be..18983da 100644 --- a/config/config.go +++ b/config/config.go @@ -17,6 +17,13 @@ import ( zfsprop "github.com/zrepl/zrepl/zfs/property" ) +type ParseFlags uint + +const ( + ParseFlagsNone ParseFlags = 0 + ParseFlagsNoCertCheck ParseFlags = 1 << iota +) + type Config struct { Jobs []JobEnum `yaml:"jobs"` Global *Global `yaml:"global,optional,fromdefaults"` diff --git a/daemon/daemon.go b/daemon/daemon.go index c5399f1..75333e0 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -51,7 +51,7 @@ func Run(ctx context.Context, conf *config.Config) error { } outlets.Add(newPrometheusLogOutlet(), logger.Debug) - confJobs, err := job.JobsFromConfig(conf) + confJobs, err := job.JobsFromConfig(conf, config.ParseFlagsNone) if err != nil { return errors.Wrap(err, "cannot build jobs from config") } diff --git a/daemon/job/active.go b/daemon/job/active.go index b540789..4eb2962 100644 --- a/daemon/job/active.go +++ b/daemon/job/active.go @@ -300,7 +300,7 @@ func replicationDriverConfigFromConfig(in *config.Replication) (c driver.Config, return c, err } -func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (j *ActiveSide, err error) { +func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}, parseFlags config.ParseFlags) (j *ActiveSide, err error) { j = &ActiveSide{} j.name, err = endpoint.MakeJobID(in.Name) @@ -349,7 +349,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) ( ConstLabels: prometheus.Labels{"zrepl_job": j.name.String()}, }) - j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect) + j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect, parseFlags) if err != nil { return nil, errors.Wrap(err, "cannot build client") } diff --git a/daemon/job/build_jobs.go b/daemon/job/build_jobs.go index 9d5b950..be630e9 100644 --- a/daemon/job/build_jobs.go +++ b/daemon/job/build_jobs.go @@ -11,10 +11,10 @@ import ( "github.com/zrepl/zrepl/util/bandwidthlimit" ) -func JobsFromConfig(c *config.Config) ([]Job, error) { +func JobsFromConfig(c *config.Config, parseFlags config.ParseFlags) ([]Job, error) { js := make([]Job, len(c.Jobs)) for i := range c.Jobs { - j, err := buildJob(c.Global, c.Jobs[i]) + j, err := buildJob(c.Global, c.Jobs[i], parseFlags) if err != nil { return nil, err } @@ -42,19 +42,19 @@ func JobsFromConfig(c *config.Config) ([]Job, error) { return js, nil } -func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) { +func buildJob(c *config.Global, in config.JobEnum, parseFlags config.ParseFlags) (j Job, err error) { cannotBuildJob := func(e error, name string) (Job, error) { return nil, errors.Wrapf(e, "cannot build job %q", name) } // FIXME prettify this switch v := in.Ret.(type) { case *config.SinkJob: - j, err = passiveSideFromConfig(c, &v.PassiveJob, v) + j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags) if err != nil { return cannotBuildJob(err, v.Name) } case *config.SourceJob: - j, err = passiveSideFromConfig(c, &v.PassiveJob, v) + j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags) if err != nil { return cannotBuildJob(err, v.Name) } @@ -64,12 +64,12 @@ func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) { return cannotBuildJob(err, v.Name) } case *config.PushJob: - j, err = activeSide(c, &v.ActiveJob, v) + j, err = activeSide(c, &v.ActiveJob, v, parseFlags) if err != nil { return cannotBuildJob(err, v.Name) } case *config.PullJob: - j, err = activeSide(c, &v.ActiveJob, v) + j, err = activeSide(c, &v.ActiveJob, v, parseFlags) if err != nil { return cannotBuildJob(err, v.Name) } diff --git a/daemon/job/build_jobs_test.go b/daemon/job/build_jobs_test.go index cc9a81a..ef0196a 100644 --- a/daemon/job/build_jobs_test.go +++ b/daemon/job/build_jobs_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/transport/tls" ) func TestValidateReceivingSidesDoNotOverlap(t *testing.T) { @@ -96,7 +95,7 @@ jobs: conf, err := config.ParseConfigBytes([]byte(fill(c.jobName))) require.NoError(t, err, "not expecting yaml-config to know about job ids") require.NotNil(t, conf) - jobs, err := JobsFromConfig(conf) + jobs, err := JobsFromConfig(conf, config.ParseFlagsNone) if c.valid { assert.NoError(t, err) @@ -153,8 +152,7 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) { t.Logf("file: %s", p) t.Log(pretty.Sprint(c)) - tls.FakeCertificateLoading(t) - jobs, err := JobsFromConfig(c) + jobs, err := JobsFromConfig(c, config.ParseFlagsNoCertCheck) t.Logf("jobs: %#v", jobs) require.NoError(t, err) @@ -299,7 +297,7 @@ jobs: t.Logf("testing config:\n%s", cstr) c, err := config.ParseConfigBytes([]byte(cstr)) require.NoError(t, err) - jobs, err := JobsFromConfig(c) + jobs, err := JobsFromConfig(c, config.ParseFlagsNone) if ts.expectOk != nil { require.NoError(t, err) require.NotNil(t, c) diff --git a/daemon/job/passive.go b/daemon/job/passive.go index 9414124..a72dd96 100644 --- a/daemon/job/passive.go +++ b/daemon/job/passive.go @@ -91,7 +91,7 @@ func (m *modeSource) SnapperReport() *snapper.Report { return m.snapper.Report() } -func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}) (s *PassiveSide, err error) { +func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}, parseFlags config.ParseFlags) (s *PassiveSide, err error) { s = &PassiveSide{} @@ -110,7 +110,7 @@ func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob in return nil, err // no wrapping necessary } - if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil { + if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve, parseFlags); err != nil { return nil, errors.Wrap(err, "cannot build listener factory") } diff --git a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go index e57c311..29b0462 100644 --- a/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go +++ b/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator.go @@ -67,7 +67,7 @@ func main() { case "connect": tc, err := getTestCase(args.testCase) noerror(err) - connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect) + connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect, config.ParseFlagsNone) noerror(err) wire, err := connecter.Connect(ctx) noerror(err) @@ -75,7 +75,7 @@ func main() { case "serve": tc, err := getTestCase(args.testCase) noerror(err) - lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve) + lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve, config.ParseFlagsNone) noerror(err) l, err := lf() noerror(err) diff --git a/transport/fromconfig/transport_fromconfig.go b/transport/fromconfig/transport_fromconfig.go index def2450..9ee144f 100644 --- a/transport/fromconfig/transport_fromconfig.go +++ b/transport/fromconfig/transport_fromconfig.go @@ -15,7 +15,7 @@ import ( "github.com/zrepl/zrepl/transport/tls" ) -func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory, error) { +func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) { var ( l transport.AuthenticatedListenerFactory @@ -25,7 +25,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport case *config.TCPServe: l, err = tcp.TCPListenerFactoryFromConfig(g, v) case *config.TLSServe: - l, err = tls.TLSListenerFactoryFromConfig(g, v) + l, err = tls.TLSListenerFactoryFromConfig(g, v, parseFlags) case *config.StdinserverServer: l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v) case *config.LocalServe: @@ -37,7 +37,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport return l, err } -func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) { +func ConnecterFromConfig(g *config.Global, in config.ConnectEnum, parseFlags config.ParseFlags) (transport.Connecter, error) { var ( connecter transport.Connecter err error @@ -48,7 +48,7 @@ func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Con case *config.TCPConnect: connecter, err = tcp.TCPConnecterFromConfig(v) case *config.TLSConnect: - connecter, err = tls.TLSConnecterFromConfig(v) + connecter, err = tls.TLSConnecterFromConfig(v, parseFlags) case *config.LocalConnect: connecter, err = local.LocalConnecterFromConfig(v) default: diff --git a/transport/tls/connect_tls.go b/transport/tls/connect_tls.go index 4fc76f9..abf2e58 100644 --- a/transport/tls/connect_tls.go +++ b/transport/tls/connect_tls.go @@ -18,12 +18,12 @@ type TLSConnecter struct { tlsConfig *tls.Config } -func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) { +func TLSConnecterFromConfig(in *config.TLSConnect, parseFlags config.ParseFlags) (*TLSConnecter, error) { dialer := net.Dialer{ Timeout: in.DialTimeout, } - if fakeCertificateLoading { + if parseFlags&config.ParseFlagsNoCertCheck != 0 { return &TLSConnecter{in.Address, dialer, nil}, nil } diff --git a/transport/tls/serve_tls.go b/transport/tls/serve_tls.go index dd156cc..2e60122 100644 --- a/transport/tls/serve_tls.go +++ b/transport/tls/serve_tls.go @@ -16,7 +16,7 @@ import ( type TLSListenerFactory struct{} -func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory, error) { +func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) { address := in.Listen handshakeTimeout := in.HandshakeTimeout @@ -25,7 +25,7 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transp return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") } - if fakeCertificateLoading { + if parseFlags&config.ParseFlagsNoCertCheck != 0 { return func() (transport.AuthenticatedListener, error) { return nil, nil }, nil } diff --git a/transport/tls/tls_test_helper.go b/transport/tls/tls_test_helper.go deleted file mode 100644 index c778425..0000000 --- a/transport/tls/tls_test_helper.go +++ /dev/null @@ -1,10 +0,0 @@ -package tls - -import "testing" - -var fakeCertificateLoading bool - -func FakeCertificateLoading(t *testing.T) { - t.Logf("faking certificate loading") - fakeCertificateLoading = true -}