diff --git a/config/config.go b/config/config.go index 02a7dae..3b76b00 100644 --- a/config/config.go +++ b/config/config.go @@ -6,8 +6,6 @@ import ( "log/syslog" "os" "reflect" - "regexp" - "strconv" "time" "github.com/pkg/errors" @@ -190,13 +188,10 @@ func (i *PositiveDurationOrManual) UnmarshalYAML(u func(interface{}, bool) error return fmt.Errorf("value must not be empty") default: i.Manual = false - i.Interval, err = time.ParseDuration(s) + i.Interval, err = parsePositiveDuration(s) if err != nil { return err } - if i.Interval <= 0 { - return fmt.Errorf("value must be a positive duration, got %q", s) - } } return nil } @@ -228,10 +223,10 @@ type SnapshottingEnum struct { } type SnapshottingPeriodic struct { - Type string `yaml:"type"` - Prefix string `yaml:"prefix"` - Interval time.Duration `yaml:"interval,positive"` - Hooks HookList `yaml:"hooks,optional"` + Type string `yaml:"type"` + Prefix string `yaml:"prefix"` + Interval *PositiveDuration `yaml:"interval"` + Hooks HookList `yaml:"hooks,optional"` } type CronSpec struct { @@ -715,41 +710,3 @@ func ParseConfigBytes(bytes []byte) (*Config, error) { } return c, nil } - -var durationStringRegex *regexp.Regexp = regexp.MustCompile(`^\s*(\d+)\s*(s|m|h|d|w)\s*$`) - -func parsePositiveDuration(e string) (d time.Duration, err error) { - comps := durationStringRegex.FindStringSubmatch(e) - if len(comps) != 3 { - err = fmt.Errorf("does not match regex: %s %#v", e, comps) - return - } - - durationFactor, err := strconv.ParseInt(comps[1], 10, 64) - if err != nil { - return 0, err - } - if durationFactor <= 0 { - return 0, errors.New("duration must be positive integer") - } - - var durationUnit time.Duration - switch comps[2] { - case "s": - durationUnit = time.Second - case "m": - durationUnit = time.Minute - case "h": - durationUnit = time.Hour - case "d": - durationUnit = 24 * time.Hour - case "w": - durationUnit = 24 * 7 * time.Hour - default: - err = fmt.Errorf("contains unknown time unit '%s'", comps[2]) - return - } - - d = time.Duration(durationFactor) * durationUnit - return -} diff --git a/config/config_duration.go b/config/config_duration.go new file mode 100644 index 0000000..91deabd --- /dev/null +++ b/config/config_duration.go @@ -0,0 +1,106 @@ +package config + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "time" + + "github.com/kr/pretty" + "github.com/zrepl/yaml-config" +) + +type Duration struct{ d time.Duration } + +func (d Duration) Duration() time.Duration { return d.d } + +var _ yaml.Unmarshaler = &Duration{} + +func (d *Duration) UnmarshalYAML(unmarshal func(v interface{}, not_strict bool) error) error { + var s string + err := unmarshal(&s, false) + if err != nil { + return err + } + d.d, err = parseDuration(s) + if err != nil { + d.d = 0 + return &yaml.TypeError{Errors: []string{fmt.Sprintf("cannot parse value %q: %s", s, err)}} + } + return nil +} + +type PositiveDuration struct{ d Duration } + +var _ yaml.Unmarshaler = &PositiveDuration{} + +func (d PositiveDuration) Duration() time.Duration { return d.d.Duration() } + +func (d *PositiveDuration) UnmarshalYAML(unmarshal func(v interface{}, not_strict bool) error) error { + err := d.d.UnmarshalYAML(unmarshal) + if err != nil { + return err + } + if d.d.Duration() <= 0 { + return fmt.Errorf("duration must be positive, got %s", d.d.Duration()) + } + return nil +} + +func parsePositiveDuration(e string) (time.Duration, error) { + d, err := parseDuration(e) + if err != nil { + return d, err + } + if d <= 0 { + return 0, errors.New("duration must be positive integer") + } + return d, err +} + +var durationStringRegex *regexp.Regexp = regexp.MustCompile(`^\s*([\+-]?\d+)\s*(|s|m|h|d|w)\s*$`) + +func parseDuration(e string) (d time.Duration, err error) { + comps := durationStringRegex.FindStringSubmatch(e) + if comps == nil { + err = fmt.Errorf("must match %s", durationStringRegex) + return + } + if len(comps) != 3 { + panic(pretty.Sprint(comps)) + } + + durationFactor, err := strconv.ParseInt(comps[1], 10, 64) + if err != nil { + return 0, err + } + + var durationUnit time.Duration + switch comps[2] { + case "": + if durationFactor != 0 { + err = fmt.Errorf("missing time unit") + return + } else { + // It's the case where user specified '0'. + // We want to allow this, just like time.ParseDuration. + } + case "s": + durationUnit = time.Second + case "m": + durationUnit = time.Minute + case "h": + durationUnit = time.Hour + case "d": + durationUnit = 24 * time.Hour + case "w": + durationUnit = 24 * 7 * time.Hour + default: + err = fmt.Errorf("contains unknown time unit '%s'", comps[2]) + return + } + + d = time.Duration(durationFactor) * durationUnit + return +} diff --git a/config/config_snapshotting_test.go b/config/config_snapshotting_test.go index 5bccf00..91ad312 100644 --- a/config/config_snapshotting_test.go +++ b/config/config_snapshotting_test.go @@ -38,6 +38,13 @@ jobs: interval: 10m ` + periodicDaily := ` + snapshotting: + type: periodic + prefix: zrepl_ + interval: 1d +` + hooks := ` snapshotting: type: periodic @@ -74,7 +81,15 @@ jobs: c = testValidConfig(t, fillSnapshotting(periodic)) snp := c.Jobs[0].Ret.(*PushJob).Snapshotting.Ret.(*SnapshottingPeriodic) assert.Equal(t, "periodic", snp.Type) - assert.Equal(t, 10*time.Minute, snp.Interval) + assert.Equal(t, 10*time.Minute, snp.Interval.Duration()) + assert.Equal(t, "zrepl_", snp.Prefix) + }) + + t.Run("periodicDaily", func(t *testing.T) { + c = testValidConfig(t, fillSnapshotting(periodicDaily)) + snp := c.Jobs[0].Ret.(*PushJob).Snapshotting.Ret.(*SnapshottingPeriodic) + assert.Equal(t, "periodic", snp.Type) + assert.Equal(t, 24*time.Hour, snp.Interval.Duration()) assert.Equal(t, "zrepl_", snp.Prefix) }) diff --git a/daemon/snapper/periodic.go b/daemon/snapper/periodic.go index 85e6824..d32306e 100644 --- a/daemon/snapper/periodic.go +++ b/daemon/snapper/periodic.go @@ -22,7 +22,7 @@ func periodicFromConfig(g *config.Global, fsf zfs.DatasetFilter, in *config.Snap if in.Prefix == "" { return nil, errors.New("prefix must not be empty") } - if in.Interval <= 0 { + if in.Interval.Duration() <= 0 { return nil, errors.New("interval must be positive") } @@ -32,7 +32,7 @@ func periodicFromConfig(g *config.Global, fsf zfs.DatasetFilter, in *config.Snap } args := periodicArgs{ - interval: in.Interval, + interval: in.Interval.Duration(), fsf: fsf, planArgs: planArgs{ prefix: in.Prefix,