mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-16 03:10:32 +01:00
83d450b1f2
fixes #18
296 lines
6.0 KiB
Go
296 lines
6.0 KiB
Go
package cmd
|
|
|
|
import (
|
|
"io/ioutil"
|
|
|
|
"fmt"
|
|
yaml "github.com/go-yaml/yaml"
|
|
"github.com/mitchellh/mapstructure"
|
|
"github.com/pkg/errors"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
var ConfigFileDefaultLocations []string = []string{
|
|
"/etc/zrepl/zrepl.yml",
|
|
"/usr/local/etc/zrepl/zrepl.yml",
|
|
}
|
|
|
|
const (
|
|
JobNameControl string = "control"
|
|
)
|
|
|
|
var ReservedJobNames []string = []string{
|
|
JobNameControl,
|
|
}
|
|
|
|
type ConfigParsingContext struct {
|
|
Global *Global
|
|
}
|
|
|
|
func ParseConfig(path string) (config *Config, err error) {
|
|
|
|
if path == "" {
|
|
// Try default locations
|
|
for _, l := range ConfigFileDefaultLocations {
|
|
stat, err := os.Stat(l)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if !stat.Mode().IsRegular() {
|
|
err = errors.Errorf("file at default location is not a regular file: %s", l)
|
|
continue
|
|
}
|
|
path = l
|
|
break
|
|
}
|
|
}
|
|
|
|
var i interface{}
|
|
|
|
var bytes []byte
|
|
|
|
if bytes, err = ioutil.ReadFile(path); err != nil {
|
|
err = errors.WithStack(err)
|
|
return
|
|
}
|
|
|
|
if err = yaml.Unmarshal(bytes, &i); err != nil {
|
|
err = errors.WithStack(err)
|
|
return
|
|
}
|
|
|
|
return parseConfig(i)
|
|
}
|
|
|
|
func parseConfig(i interface{}) (c *Config, err error) {
|
|
|
|
var asMap struct {
|
|
Global map[string]interface{}
|
|
Jobs []map[string]interface{}
|
|
}
|
|
if err := mapstructure.Decode(i, &asMap); err != nil {
|
|
return nil, errors.Wrap(err, "config root must be a dict")
|
|
}
|
|
|
|
c = &Config{}
|
|
|
|
// Parse global with defaults
|
|
c.Global.Serve.Stdinserver.SockDir = "/var/run/zrepl/stdinserver"
|
|
c.Global.Control.Sockpath = "/var/run/zrepl/control"
|
|
|
|
err = mapstructure.Decode(asMap.Global, &c.Global)
|
|
if err != nil {
|
|
err = errors.Wrap(err, "mapstructure error on 'global' section: %s")
|
|
return
|
|
}
|
|
|
|
if c.Global.logging, err = parseLogging(asMap.Global["logging"]); err != nil {
|
|
return nil, errors.Wrap(err, "cannot parse logging section")
|
|
}
|
|
|
|
cpc := ConfigParsingContext{&c.Global}
|
|
jpc := JobParsingContext{cpc}
|
|
|
|
// Jobs
|
|
c.Jobs = make(map[string]Job, len(asMap.Jobs))
|
|
for i := range asMap.Jobs {
|
|
job, err := parseJob(jpc, asMap.Jobs[i])
|
|
if err != nil {
|
|
// Try to find its name
|
|
namei, ok := asMap.Jobs[i]["name"]
|
|
if !ok {
|
|
namei = fmt.Sprintf("<no name, entry #%d in list>", i)
|
|
}
|
|
err = errors.Wrapf(err, "cannot parse job '%v'", namei)
|
|
return nil, err
|
|
}
|
|
jn := job.JobName()
|
|
if _, ok := c.Jobs[jn]; ok {
|
|
err = errors.Errorf("duplicate job name: %s", jn)
|
|
return nil, err
|
|
}
|
|
c.Jobs[job.JobName()] = job
|
|
}
|
|
|
|
cj, err := NewControlJob(JobNameControl, jpc.Global.Control.Sockpath)
|
|
if err != nil {
|
|
err = errors.Wrap(err, "cannot create control job")
|
|
return
|
|
}
|
|
c.Jobs[JobNameControl] = cj
|
|
|
|
return c, nil
|
|
|
|
}
|
|
|
|
func extractStringField(i map[string]interface{}, key string, notempty bool) (field string, err error) {
|
|
vi, ok := i[key]
|
|
if !ok {
|
|
err = errors.Errorf("must have field '%s'", key)
|
|
return "", err
|
|
}
|
|
field, ok = vi.(string)
|
|
if !ok {
|
|
err = errors.Errorf("'%s' field must have type string", key)
|
|
return "", err
|
|
}
|
|
if notempty && len(field) <= 0 {
|
|
err = errors.Errorf("'%s' field must not be empty", key)
|
|
return "", err
|
|
}
|
|
return
|
|
}
|
|
|
|
type JobParsingContext struct {
|
|
ConfigParsingContext
|
|
}
|
|
|
|
func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error) {
|
|
|
|
name, err := extractStringField(i, "name", true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
for _, r := range ReservedJobNames {
|
|
if name == r {
|
|
err = errors.Errorf("job name '%s' is reserved", name)
|
|
return
|
|
}
|
|
}
|
|
|
|
jobtype, err := extractStringField(i, "type", true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
switch jobtype {
|
|
case "pull":
|
|
return parsePullJob(c, name, i)
|
|
case "source":
|
|
return parseSourceJob(c, name, i)
|
|
case "local":
|
|
return parseLocalJob(c, name, i)
|
|
default:
|
|
return nil, errors.Errorf("unknown job type '%s'", jobtype)
|
|
}
|
|
|
|
}
|
|
|
|
func parseConnect(i map[string]interface{}) (c RWCConnecter, err error) {
|
|
|
|
t, err := extractStringField(i, "type", true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch t {
|
|
case "ssh+stdinserver":
|
|
return parseSSHStdinserverConnecter(i)
|
|
default:
|
|
return nil, errors.Errorf("unknown connection type '%s'", t)
|
|
}
|
|
|
|
}
|
|
|
|
func parseInitialReplPolicy(v interface{}, defaultPolicy InitialReplPolicy) (p InitialReplPolicy, err error) {
|
|
s, ok := v.(string)
|
|
if !ok {
|
|
goto err
|
|
}
|
|
|
|
switch {
|
|
case s == "":
|
|
p = defaultPolicy
|
|
case s == "most_recent":
|
|
p = InitialReplPolicyMostRecent
|
|
case s == "all":
|
|
p = InitialReplPolicyAll
|
|
default:
|
|
goto err
|
|
}
|
|
|
|
return
|
|
|
|
err:
|
|
err = errors.New(fmt.Sprintf("expected InitialReplPolicy, got %#v", v))
|
|
return
|
|
}
|
|
|
|
func parsePrunePolicy(v map[string]interface{}) (p PrunePolicy, err error) {
|
|
|
|
policyName, err := extractStringField(v, "policy", true)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
switch policyName {
|
|
case "grid":
|
|
return parseGridPrunePolicy(v)
|
|
case "noprune":
|
|
return NoPrunePolicy{}, nil
|
|
default:
|
|
err = errors.Errorf("unknown policy '%s'", policyName)
|
|
return
|
|
}
|
|
|
|
}
|
|
|
|
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) {
|
|
|
|
t, err := extractStringField(v, "type", true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch t {
|
|
case "stdinserver":
|
|
return parseStdinserverListenerFactory(c, v)
|
|
default:
|
|
err = errors.Errorf("unknown type '%s'", t)
|
|
return
|
|
}
|
|
|
|
}
|
|
|
|
var durationStringRegex *regexp.Regexp = regexp.MustCompile(`^\s*(\d+)\s*(s|m|h|d|w)\s*$`)
|
|
|
|
func parsePostitiveDuration(e string) (d time.Duration, err error) {
|
|
comps := durationStringRegex.FindStringSubmatch(e)
|
|
if len(comps) != 3 {
|
|
err = fmt.Errorf("does not match regex: %s %#v", e, comps)
|
|
return
|
|
}
|
|
|
|
durationFactor, err := strconv.ParseInt(comps[1], 10, 64)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if durationFactor <= 0 {
|
|
return 0, errors.New("duration must be positive integer")
|
|
}
|
|
|
|
var durationUnit time.Duration
|
|
switch comps[2] {
|
|
case "s":
|
|
durationUnit = time.Second
|
|
case "m":
|
|
durationUnit = time.Minute
|
|
case "h":
|
|
durationUnit = time.Hour
|
|
case "d":
|
|
durationUnit = 24 * time.Hour
|
|
case "w":
|
|
durationUnit = 24 * 7 * time.Hour
|
|
default:
|
|
err = fmt.Errorf("contains unknown time unit '%s'", comps[2])
|
|
return
|
|
}
|
|
|
|
d = time.Duration(durationFactor) * durationUnit
|
|
return
|
|
}
|