mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-22 08:23:50 +01:00
Add --skip-cert-check
flag to zrepl configcheck
to prevent checking cert files
It may be desirable to check that a config is valid without checking for the existence of certificate files (e.g. when validating a config inside a sandbox without access to the cert files). This will be very useful for NixOS so that we can check the config file at nix-build time (e.g. potentially without proper permissions to read cert files for a TLS connection). fixes https://github.com/zrepl/zrepl/issues/467 closes https://github.com/zrepl/zrepl/pull/587
This commit is contained in:
parent
e4112d888c
commit
1df0f8912a
@ -21,6 +21,7 @@ import (
|
|||||||
var configcheckArgs struct {
|
var configcheckArgs struct {
|
||||||
format string
|
format string
|
||||||
what string
|
what string
|
||||||
|
skipCertCheck bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var ConfigcheckCmd = &cli.Subcommand{
|
var ConfigcheckCmd = &cli.Subcommand{
|
||||||
@ -29,6 +30,7 @@ var ConfigcheckCmd = &cli.Subcommand{
|
|||||||
SetupFlags: func(f *pflag.FlagSet) {
|
SetupFlags: func(f *pflag.FlagSet) {
|
||||||
f.StringVar(&configcheckArgs.format, "format", "", "dump parsed config object [pretty|yaml|json]")
|
f.StringVar(&configcheckArgs.format, "format", "", "dump parsed config object [pretty|yaml|json]")
|
||||||
f.StringVar(&configcheckArgs.what, "what", "all", "what to print [all|config|jobs|logging]")
|
f.StringVar(&configcheckArgs.what, "what", "all", "what to print [all|config|jobs|logging]")
|
||||||
|
f.BoolVar(&configcheckArgs.skipCertCheck, "skip-cert-check", false, "skip checking cert files")
|
||||||
},
|
},
|
||||||
Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
|
Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
|
||||||
formatMap := map[string]func(interface{}){
|
formatMap := map[string]func(interface{}){
|
||||||
@ -56,8 +58,16 @@ var ConfigcheckCmd = &cli.Subcommand{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var hadErr bool
|
var hadErr bool
|
||||||
|
|
||||||
|
parseFlags := config.ParseFlagsNone
|
||||||
|
|
||||||
|
if configcheckArgs.skipCertCheck {
|
||||||
|
parseFlags |= config.ParseFlagsNoCertCheck
|
||||||
|
}
|
||||||
|
|
||||||
// further: try to build jobs
|
// further: try to build jobs
|
||||||
confJobs, err := job.JobsFromConfig(subcommand.Config())
|
confJobs, err := job.JobsFromConfig(subcommand.Config(), parseFlags)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := errors.Wrap(err, "cannot build jobs from config")
|
err := errors.Wrap(err, "cannot build jobs from config")
|
||||||
if configcheckArgs.what == "jobs" {
|
if configcheckArgs.what == "jobs" {
|
||||||
|
@ -129,7 +129,7 @@ func doMigrateReplicationCursor(ctx context.Context, sc *cli.Subcommand, args []
|
|||||||
}
|
}
|
||||||
|
|
||||||
cfg := sc.Config()
|
cfg := sc.Config()
|
||||||
jobs, err := job.JobsFromConfig(cfg)
|
jobs, err := job.JobsFromConfig(cfg, config.ParseFlagsNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("cannot parse config:\n%s\n\n", err)
|
fmt.Printf("cannot parse config:\n%s\n\n", err)
|
||||||
fmt.Printf("NOTE: this migration was released together with a change in job name requirements.\n")
|
fmt.Printf("NOTE: this migration was released together with a change in job name requirements.\n")
|
||||||
|
@ -17,6 +17,13 @@ import (
|
|||||||
zfsprop "github.com/zrepl/zrepl/zfs/property"
|
zfsprop "github.com/zrepl/zrepl/zfs/property"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ParseFlags uint
|
||||||
|
|
||||||
|
const (
|
||||||
|
ParseFlagsNone ParseFlags = 0
|
||||||
|
ParseFlagsNoCertCheck ParseFlags = 1 << iota
|
||||||
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Jobs []JobEnum `yaml:"jobs"`
|
Jobs []JobEnum `yaml:"jobs"`
|
||||||
Global *Global `yaml:"global,optional,fromdefaults"`
|
Global *Global `yaml:"global,optional,fromdefaults"`
|
||||||
|
@ -51,7 +51,7 @@ func Run(ctx context.Context, conf *config.Config) error {
|
|||||||
}
|
}
|
||||||
outlets.Add(newPrometheusLogOutlet(), logger.Debug)
|
outlets.Add(newPrometheusLogOutlet(), logger.Debug)
|
||||||
|
|
||||||
confJobs, err := job.JobsFromConfig(conf)
|
confJobs, err := job.JobsFromConfig(conf, config.ParseFlagsNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "cannot build jobs from config")
|
return errors.Wrap(err, "cannot build jobs from config")
|
||||||
}
|
}
|
||||||
|
@ -300,7 +300,7 @@ func replicationDriverConfigFromConfig(in *config.Replication) (c driver.Config,
|
|||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (j *ActiveSide, err error) {
|
func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}, parseFlags config.ParseFlags) (j *ActiveSide, err error) {
|
||||||
|
|
||||||
j = &ActiveSide{}
|
j = &ActiveSide{}
|
||||||
j.name, err = endpoint.MakeJobID(in.Name)
|
j.name, err = endpoint.MakeJobID(in.Name)
|
||||||
@ -349,7 +349,7 @@ func activeSide(g *config.Global, in *config.ActiveJob, configJob interface{}) (
|
|||||||
ConstLabels: prometheus.Labels{"zrepl_job": j.name.String()},
|
ConstLabels: prometheus.Labels{"zrepl_job": j.name.String()},
|
||||||
})
|
})
|
||||||
|
|
||||||
j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect)
|
j.connecter, err = fromconfig.ConnecterFromConfig(g, in.Connect, parseFlags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot build client")
|
return nil, errors.Wrap(err, "cannot build client")
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@ import (
|
|||||||
"github.com/zrepl/zrepl/util/bandwidthlimit"
|
"github.com/zrepl/zrepl/util/bandwidthlimit"
|
||||||
)
|
)
|
||||||
|
|
||||||
func JobsFromConfig(c *config.Config) ([]Job, error) {
|
func JobsFromConfig(c *config.Config, parseFlags config.ParseFlags) ([]Job, error) {
|
||||||
js := make([]Job, len(c.Jobs))
|
js := make([]Job, len(c.Jobs))
|
||||||
for i := range c.Jobs {
|
for i := range c.Jobs {
|
||||||
j, err := buildJob(c.Global, c.Jobs[i])
|
j, err := buildJob(c.Global, c.Jobs[i], parseFlags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -42,19 +42,19 @@ func JobsFromConfig(c *config.Config) ([]Job, error) {
|
|||||||
return js, nil
|
return js, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) {
|
func buildJob(c *config.Global, in config.JobEnum, parseFlags config.ParseFlags) (j Job, err error) {
|
||||||
cannotBuildJob := func(e error, name string) (Job, error) {
|
cannotBuildJob := func(e error, name string) (Job, error) {
|
||||||
return nil, errors.Wrapf(e, "cannot build job %q", name)
|
return nil, errors.Wrapf(e, "cannot build job %q", name)
|
||||||
}
|
}
|
||||||
// FIXME prettify this
|
// FIXME prettify this
|
||||||
switch v := in.Ret.(type) {
|
switch v := in.Ret.(type) {
|
||||||
case *config.SinkJob:
|
case *config.SinkJob:
|
||||||
j, err = passiveSideFromConfig(c, &v.PassiveJob, v)
|
j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cannotBuildJob(err, v.Name)
|
return cannotBuildJob(err, v.Name)
|
||||||
}
|
}
|
||||||
case *config.SourceJob:
|
case *config.SourceJob:
|
||||||
j, err = passiveSideFromConfig(c, &v.PassiveJob, v)
|
j, err = passiveSideFromConfig(c, &v.PassiveJob, v, parseFlags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cannotBuildJob(err, v.Name)
|
return cannotBuildJob(err, v.Name)
|
||||||
}
|
}
|
||||||
@ -64,12 +64,12 @@ func buildJob(c *config.Global, in config.JobEnum) (j Job, err error) {
|
|||||||
return cannotBuildJob(err, v.Name)
|
return cannotBuildJob(err, v.Name)
|
||||||
}
|
}
|
||||||
case *config.PushJob:
|
case *config.PushJob:
|
||||||
j, err = activeSide(c, &v.ActiveJob, v)
|
j, err = activeSide(c, &v.ActiveJob, v, parseFlags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cannotBuildJob(err, v.Name)
|
return cannotBuildJob(err, v.Name)
|
||||||
}
|
}
|
||||||
case *config.PullJob:
|
case *config.PullJob:
|
||||||
j, err = activeSide(c, &v.ActiveJob, v)
|
j, err = activeSide(c, &v.ActiveJob, v, parseFlags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cannotBuildJob(err, v.Name)
|
return cannotBuildJob(err, v.Name)
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/zrepl/zrepl/config"
|
"github.com/zrepl/zrepl/config"
|
||||||
"github.com/zrepl/zrepl/transport/tls"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateReceivingSidesDoNotOverlap(t *testing.T) {
|
func TestValidateReceivingSidesDoNotOverlap(t *testing.T) {
|
||||||
@ -96,7 +95,7 @@ jobs:
|
|||||||
conf, err := config.ParseConfigBytes([]byte(fill(c.jobName)))
|
conf, err := config.ParseConfigBytes([]byte(fill(c.jobName)))
|
||||||
require.NoError(t, err, "not expecting yaml-config to know about job ids")
|
require.NoError(t, err, "not expecting yaml-config to know about job ids")
|
||||||
require.NotNil(t, conf)
|
require.NotNil(t, conf)
|
||||||
jobs, err := JobsFromConfig(conf)
|
jobs, err := JobsFromConfig(conf, config.ParseFlagsNone)
|
||||||
|
|
||||||
if c.valid {
|
if c.valid {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -153,8 +152,7 @@ func TestSampleConfigsAreBuiltWithoutErrors(t *testing.T) {
|
|||||||
t.Logf("file: %s", p)
|
t.Logf("file: %s", p)
|
||||||
t.Log(pretty.Sprint(c))
|
t.Log(pretty.Sprint(c))
|
||||||
|
|
||||||
tls.FakeCertificateLoading(t)
|
jobs, err := JobsFromConfig(c, config.ParseFlagsNoCertCheck)
|
||||||
jobs, err := JobsFromConfig(c)
|
|
||||||
t.Logf("jobs: %#v", jobs)
|
t.Logf("jobs: %#v", jobs)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -299,7 +297,7 @@ jobs:
|
|||||||
t.Logf("testing config:\n%s", cstr)
|
t.Logf("testing config:\n%s", cstr)
|
||||||
c, err := config.ParseConfigBytes([]byte(cstr))
|
c, err := config.ParseConfigBytes([]byte(cstr))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
jobs, err := JobsFromConfig(c)
|
jobs, err := JobsFromConfig(c, config.ParseFlagsNone)
|
||||||
if ts.expectOk != nil {
|
if ts.expectOk != nil {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, c)
|
require.NotNil(t, c)
|
||||||
|
@ -91,7 +91,7 @@ func (m *modeSource) SnapperReport() *snapper.Report {
|
|||||||
return m.snapper.Report()
|
return m.snapper.Report()
|
||||||
}
|
}
|
||||||
|
|
||||||
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}) (s *PassiveSide, err error) {
|
func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob interface{}, parseFlags config.ParseFlags) (s *PassiveSide, err error) {
|
||||||
|
|
||||||
s = &PassiveSide{}
|
s = &PassiveSide{}
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ func passiveSideFromConfig(g *config.Global, in *config.PassiveJob, configJob in
|
|||||||
return nil, err // no wrapping necessary
|
return nil, err // no wrapping necessary
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve); err != nil {
|
if s.listen, err = fromconfig.ListenerFactoryFromConfig(g, in.Serve, parseFlags); err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot build listener factory")
|
return nil, errors.Wrap(err, "cannot build listener factory")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ func main() {
|
|||||||
case "connect":
|
case "connect":
|
||||||
tc, err := getTestCase(args.testCase)
|
tc, err := getTestCase(args.testCase)
|
||||||
noerror(err)
|
noerror(err)
|
||||||
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect)
|
connecter, err := transportconfig.ConnecterFromConfig(global, conf.Connect, config.ParseFlagsNone)
|
||||||
noerror(err)
|
noerror(err)
|
||||||
wire, err := connecter.Connect(ctx)
|
wire, err := connecter.Connect(ctx)
|
||||||
noerror(err)
|
noerror(err)
|
||||||
@ -75,7 +75,7 @@ func main() {
|
|||||||
case "serve":
|
case "serve":
|
||||||
tc, err := getTestCase(args.testCase)
|
tc, err := getTestCase(args.testCase)
|
||||||
noerror(err)
|
noerror(err)
|
||||||
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve)
|
lf, err := transportconfig.ListenerFactoryFromConfig(global, conf.Serve, config.ParseFlagsNone)
|
||||||
noerror(err)
|
noerror(err)
|
||||||
l, err := lf()
|
l, err := lf()
|
||||||
noerror(err)
|
noerror(err)
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/zrepl/zrepl/transport/tls"
|
"github.com/zrepl/zrepl/transport/tls"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport.AuthenticatedListenerFactory, error) {
|
func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
l transport.AuthenticatedListenerFactory
|
l transport.AuthenticatedListenerFactory
|
||||||
@ -25,7 +25,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport
|
|||||||
case *config.TCPServe:
|
case *config.TCPServe:
|
||||||
l, err = tcp.TCPListenerFactoryFromConfig(g, v)
|
l, err = tcp.TCPListenerFactoryFromConfig(g, v)
|
||||||
case *config.TLSServe:
|
case *config.TLSServe:
|
||||||
l, err = tls.TLSListenerFactoryFromConfig(g, v)
|
l, err = tls.TLSListenerFactoryFromConfig(g, v, parseFlags)
|
||||||
case *config.StdinserverServer:
|
case *config.StdinserverServer:
|
||||||
l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v)
|
l, err = ssh.MultiStdinserverListenerFactoryFromConfig(g, v)
|
||||||
case *config.LocalServe:
|
case *config.LocalServe:
|
||||||
@ -37,7 +37,7 @@ func ListenerFactoryFromConfig(g *config.Global, in config.ServeEnum) (transport
|
|||||||
return l, err
|
return l, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Connecter, error) {
|
func ConnecterFromConfig(g *config.Global, in config.ConnectEnum, parseFlags config.ParseFlags) (transport.Connecter, error) {
|
||||||
var (
|
var (
|
||||||
connecter transport.Connecter
|
connecter transport.Connecter
|
||||||
err error
|
err error
|
||||||
@ -48,7 +48,7 @@ func ConnecterFromConfig(g *config.Global, in config.ConnectEnum) (transport.Con
|
|||||||
case *config.TCPConnect:
|
case *config.TCPConnect:
|
||||||
connecter, err = tcp.TCPConnecterFromConfig(v)
|
connecter, err = tcp.TCPConnecterFromConfig(v)
|
||||||
case *config.TLSConnect:
|
case *config.TLSConnect:
|
||||||
connecter, err = tls.TLSConnecterFromConfig(v)
|
connecter, err = tls.TLSConnecterFromConfig(v, parseFlags)
|
||||||
case *config.LocalConnect:
|
case *config.LocalConnect:
|
||||||
connecter, err = local.LocalConnecterFromConfig(v)
|
connecter, err = local.LocalConnecterFromConfig(v)
|
||||||
default:
|
default:
|
||||||
|
@ -18,12 +18,12 @@ type TLSConnecter struct {
|
|||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func TLSConnecterFromConfig(in *config.TLSConnect) (*TLSConnecter, error) {
|
func TLSConnecterFromConfig(in *config.TLSConnect, parseFlags config.ParseFlags) (*TLSConnecter, error) {
|
||||||
dialer := net.Dialer{
|
dialer := net.Dialer{
|
||||||
Timeout: in.DialTimeout,
|
Timeout: in.DialTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
if fakeCertificateLoading {
|
if parseFlags&config.ParseFlagsNoCertCheck != 0 {
|
||||||
return &TLSConnecter{in.Address, dialer, nil}, nil
|
return &TLSConnecter{in.Address, dialer, nil}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ import (
|
|||||||
|
|
||||||
type TLSListenerFactory struct{}
|
type TLSListenerFactory struct{}
|
||||||
|
|
||||||
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transport.AuthenticatedListenerFactory, error) {
|
func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe, parseFlags config.ParseFlags) (transport.AuthenticatedListenerFactory, error) {
|
||||||
|
|
||||||
address := in.Listen
|
address := in.Listen
|
||||||
handshakeTimeout := in.HandshakeTimeout
|
handshakeTimeout := in.HandshakeTimeout
|
||||||
@ -25,7 +25,7 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (transp
|
|||||||
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
|
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
if fakeCertificateLoading {
|
if parseFlags&config.ParseFlagsNoCertCheck != 0 {
|
||||||
return func() (transport.AuthenticatedListener, error) { return nil, nil }, nil
|
return func() (transport.AuthenticatedListener, error) { return nil, nil }, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
package tls
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
var fakeCertificateLoading bool
|
|
||||||
|
|
||||||
func FakeCertificateLoading(t *testing.T) {
|
|
||||||
t.Logf("faking certificate loading")
|
|
||||||
fakeCertificateLoading = true
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user