package cmd import ( "io/ioutil" "fmt" yaml "github.com/go-yaml/yaml" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" "os" "regexp" "strconv" "time" ) var ConfigFileDefaultLocations []string = []string{ "/etc/zrepl/zrepl.yml", "/usr/local/etc/zrepl/zrepl.yml", } const ( JobNameControl string = "control" ) var ReservedJobNames []string = []string{ JobNameControl, } type ConfigParsingContext struct { Global *Global } func ParseConfig(path string) (config *Config, err error) { if path == "" { // Try default locations for _, l := range ConfigFileDefaultLocations { stat, err := os.Stat(l) if err != nil { continue } if !stat.Mode().IsRegular() { err = errors.Errorf("file at default location is not a regular file: %s", l) continue } path = l break } } var i interface{} var bytes []byte if bytes, err = ioutil.ReadFile(path); err != nil { err = errors.WithStack(err) return } if err = yaml.Unmarshal(bytes, &i); err != nil { err = errors.WithStack(err) return } return parseConfig(i) } func parseConfig(i interface{}) (c *Config, err error) { var asMap struct { Global map[string]interface{} Jobs []map[string]interface{} } if err := mapstructure.Decode(i, &asMap); err != nil { return nil, errors.Wrap(err, "config root must be a dict") } c = &Config{} // Parse global with defaults c.Global.Serve.Stdinserver.SockDir = "/var/run/zrepl/stdinserver" c.Global.Control.Sockpath = "/var/run/zrepl/control" err = mapstructure.Decode(asMap.Global, &c.Global) if err != nil { err = errors.Wrap(err, "mapstructure error on 'global' section: %s") return } if c.Global.logging, err = parseLogging(asMap.Global["logging"]); err != nil { return nil, errors.Wrap(err, "cannot parse logging section") } cpc := ConfigParsingContext{&c.Global} jpc := JobParsingContext{cpc} // Jobs c.Jobs = make(map[string]Job, len(asMap.Jobs)) for i := range asMap.Jobs { job, err := parseJob(jpc, asMap.Jobs[i]) if err != nil { // Try to find its name namei, ok := asMap.Jobs[i]["name"] if !ok { namei = fmt.Sprintf("", i) } err = errors.Wrapf(err, "cannot parse job '%v'", namei) return nil, err } jn := job.JobName() if _, ok := c.Jobs[jn]; ok { err = errors.Errorf("duplicate job name: %s", jn) return nil, err } c.Jobs[job.JobName()] = job } cj, err := NewControlJob(JobNameControl, jpc.Global.Control.Sockpath) if err != nil { err = errors.Wrap(err, "cannot create control job") return } c.Jobs[JobNameControl] = cj return c, nil } func extractStringField(i map[string]interface{}, key string, notempty bool) (field string, err error) { vi, ok := i[key] if !ok { err = errors.Errorf("must have field '%s'", key) return "", err } field, ok = vi.(string) if !ok { err = errors.Errorf("'%s' field must have type string", key) return "", err } if notempty && len(field) <= 0 { err = errors.Errorf("'%s' field must not be empty", key) return "", err } return } type JobParsingContext struct { ConfigParsingContext } func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error) { name, err := extractStringField(i, "name", true) if err != nil { return } for _, r := range ReservedJobNames { if name == r { err = errors.Errorf("job name '%s' is reserved", name) return } } jobtype, err := extractStringField(i, "type", true) if err != nil { return } switch jobtype { case "pull": return parsePullJob(c, name, i) case "source": return parseSourceJob(c, name, i) case "local": return parseLocalJob(c, name, i) default: return nil, errors.Errorf("unknown job type '%s'", jobtype) } } func parseConnect(i map[string]interface{}) (c RWCConnecter, err error) { t, err := extractStringField(i, "type", true) if err != nil { return nil, err } switch t { case "ssh+stdinserver": return parseSSHStdinserverConnecter(i) default: return nil, errors.Errorf("unknown connection type '%s'", t) } } func parseInitialReplPolicy(v interface{}, defaultPolicy InitialReplPolicy) (p InitialReplPolicy, err error) { s, ok := v.(string) if !ok { goto err } switch { case s == "": p = defaultPolicy case s == "most_recent": p = InitialReplPolicyMostRecent case s == "all": p = InitialReplPolicyAll default: goto err } return err: err = errors.New(fmt.Sprintf("expected InitialReplPolicy, got %#v", v)) return } func parsePrunePolicy(v map[string]interface{}) (p PrunePolicy, err error) { policyName, err := extractStringField(v, "policy", true) if err != nil { return } switch policyName { case "grid": return parseGridPrunePolicy(v) case "noprune": return NoPrunePolicy{}, nil default: err = errors.Errorf("unknown policy '%s'", policyName) return } } func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) { t, err := extractStringField(v, "type", true) if err != nil { return nil, err } switch t { case "stdinserver": return parseStdinserverListenerFactory(c, v) default: err = errors.Errorf("unknown type '%s'", t) return } } var durationStringRegex *regexp.Regexp = regexp.MustCompile(`^\s*(\d+)\s*(s|m|h|d|w)\s*$`) func parsePostitiveDuration(e string) (d time.Duration, err error) { comps := durationStringRegex.FindStringSubmatch(e) if len(comps) != 3 { err = fmt.Errorf("does not match regex: %s %#v", e, comps) return } durationFactor, err := strconv.ParseInt(comps[1], 10, 64) if err != nil { return 0, err } if durationFactor <= 0 { return 0, errors.New("duration must be positive integer") } var durationUnit time.Duration switch comps[2] { case "s": durationUnit = time.Second case "m": durationUnit = time.Minute case "h": durationUnit = time.Hour case "d": durationUnit = 24 * time.Hour case "w": durationUnit = 24 * 7 * time.Hour default: err = fmt.Errorf("contains unknown time unit '%s'", comps[2]) return } d = time.Duration(durationFactor) * durationUnit return }