mirror of
https://github.com/zrepl/zrepl.git
synced 2024-12-23 07:28:57 +01:00
cmd: remove global state in main.go
* refactoring * Now supporting default config locations
This commit is contained in:
parent
4ac7e78e2b
commit
9cd83399d3
@ -23,7 +23,7 @@ type LocalJob struct {
|
||||
Debug JobDebugSettings
|
||||
}
|
||||
|
||||
func parseLocalJob(name string, i map[string]interface{}) (j *LocalJob, err error) {
|
||||
func parseLocalJob(c JobParsingContext, name string, i map[string]interface{}) (j *LocalJob, err error) {
|
||||
|
||||
var asMap struct {
|
||||
Mapping map[string]string
|
||||
|
@ -23,7 +23,7 @@ type PullJob struct {
|
||||
Debug JobDebugSettings
|
||||
}
|
||||
|
||||
func parsePullJob(name string, i map[string]interface{}) (j *PullJob, err error) {
|
||||
func parsePullJob(c JobParsingContext, name string, i map[string]interface{}) (j *PullJob, err error) {
|
||||
|
||||
var asMap struct {
|
||||
Connect map[string]interface{}
|
||||
|
@ -20,7 +20,7 @@ type SourceJob struct {
|
||||
Debug JobDebugSettings
|
||||
}
|
||||
|
||||
func parseSourceJob(name string, i map[string]interface{}) (j *SourceJob, err error) {
|
||||
func parseSourceJob(c JobParsingContext, name string, i map[string]interface{}) (j *SourceJob, err error) {
|
||||
|
||||
var asMap struct {
|
||||
Serve map[string]interface{}
|
||||
@ -38,7 +38,7 @@ func parseSourceJob(name string, i map[string]interface{}) (j *SourceJob, err er
|
||||
|
||||
j = &SourceJob{Name: name}
|
||||
|
||||
if j.Serve, err = parseAuthenticatedChannelListenerFactory(asMap.Serve); err != nil {
|
||||
if j.Serve, err = parseAuthenticatedChannelListenerFactory(c, asMap.Serve); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -3,13 +3,44 @@ package cmd
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
"os"
|
||||
)
|
||||
|
||||
func ParseConfig(path string) (config *Config, err error) {
|
||||
var ConfigFileDefaultLocations []string = []string{
|
||||
"/etc/zrepl/zrepl.yml",
|
||||
"/usr/local/etc/zrepl/zrepl.yml",
|
||||
}
|
||||
|
||||
type ConfigParsingContext struct {
|
||||
Global *Global
|
||||
}
|
||||
|
||||
func ParseConfig(ctx context.Context, path string) (config *Config, err error) {
|
||||
|
||||
log := ctx.Value(contextKeyLog).(Logger)
|
||||
|
||||
if path == "" {
|
||||
// Try default locations
|
||||
for _, l := range ConfigFileDefaultLocations {
|
||||
log.Printf("trying config location %s", l)
|
||||
stat, err := os.Stat(l)
|
||||
if err != nil {
|
||||
log.Printf("stat error: %s", err)
|
||||
continue
|
||||
}
|
||||
if !stat.Mode().IsRegular() {
|
||||
log.Printf("warning: file at default location is not a regular file: %s", l)
|
||||
continue
|
||||
}
|
||||
path = l
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var i interface{}
|
||||
|
||||
@ -48,10 +79,13 @@ func parseConfig(i interface{}) (c *Config, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
cpc := ConfigParsingContext{&c.Global}
|
||||
jpc := JobParsingContext{cpc}
|
||||
|
||||
// Parse Jobs
|
||||
c.Jobs = make(map[string]Job, len(asMap.Jobs))
|
||||
for i := range asMap.Jobs {
|
||||
job, err := parseJob(asMap.Jobs[i])
|
||||
job, err := parseJob(jpc, asMap.Jobs[i])
|
||||
if err != nil {
|
||||
// Try to find its name
|
||||
namei, ok := asMap.Jobs[i]["name"]
|
||||
@ -86,7 +120,11 @@ func extractStringField(i map[string]interface{}, key string, notempty bool) (fi
|
||||
return
|
||||
}
|
||||
|
||||
func parseJob(i map[string]interface{}) (j Job, err error) {
|
||||
type JobParsingContext struct {
|
||||
ConfigParsingContext
|
||||
}
|
||||
|
||||
func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error) {
|
||||
|
||||
name, err := extractStringField(i, "name", true)
|
||||
if err != nil {
|
||||
@ -101,11 +139,11 @@ func parseJob(i map[string]interface{}) (j Job, err error) {
|
||||
|
||||
switch jobtype {
|
||||
case "pull":
|
||||
return parsePullJob(name, i)
|
||||
return parsePullJob(c, name, i)
|
||||
case "source":
|
||||
return parseSourceJob(name, i)
|
||||
return parseSourceJob(c, name, i)
|
||||
case "local":
|
||||
return parseLocalJob(name, i)
|
||||
return parseLocalJob(c, name, i)
|
||||
default:
|
||||
return nil, errors.Errorf("unknown job type '%s'", jobtype)
|
||||
}
|
||||
@ -179,7 +217,7 @@ func parsePrunePolicy(v map[string]interface{}) (p PrunePolicy, err error) {
|
||||
|
||||
}
|
||||
|
||||
func parseAuthenticatedChannelListenerFactory(v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) {
|
||||
func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) {
|
||||
|
||||
t, err := extractStringField(v, "type", true)
|
||||
if err != nil {
|
||||
@ -188,7 +226,7 @@ func parseAuthenticatedChannelListenerFactory(v map[string]interface{}) (p Authe
|
||||
|
||||
switch t {
|
||||
case "stdinserver":
|
||||
return parseStdinserverListenerFactory(v)
|
||||
return parseStdinserverListenerFactory(c, v)
|
||||
default:
|
||||
err = errors.Errorf("unknown type '%s'", t)
|
||||
return
|
||||
|
@ -13,9 +13,10 @@ import (
|
||||
|
||||
type StdinserverListenerFactory struct {
|
||||
ClientIdentity string `mapstructure:"client_identity"`
|
||||
sockaddr *net.UnixAddr
|
||||
}
|
||||
|
||||
func parseStdinserverListenerFactory(i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
||||
func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
||||
|
||||
f = &StdinserverListenerFactory{}
|
||||
|
||||
@ -26,11 +27,17 @@ func parseStdinserverListenerFactory(i map[string]interface{}) (f *StdinserverLi
|
||||
err = errors.Errorf("must specify 'client_identity'")
|
||||
return
|
||||
}
|
||||
|
||||
f.sockaddr, err = stdinserverListenerSocket(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func stdinserverListenerSockpath(clientIdentity string) (addr *net.UnixAddr, err error) {
|
||||
sockpath := path.Join(conf.Global.Serve.Stdinserver.SockDir, clientIdentity)
|
||||
func stdinserverListenerSocket(sockdir, clientIdentity string) (addr *net.UnixAddr, err error) {
|
||||
sockpath := path.Join(sockdir, clientIdentity)
|
||||
addr, err = net.ResolveUnixAddr("unix", sockpath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot resolve unix address")
|
||||
@ -40,9 +47,7 @@ func stdinserverListenerSockpath(clientIdentity string) (addr *net.UnixAddr, err
|
||||
|
||||
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
|
||||
|
||||
unixaddr, err := stdinserverListenerSockpath(f.ClientIdentity)
|
||||
|
||||
sockdir := filepath.Dir(unixaddr.Name)
|
||||
sockdir := filepath.Dir(f.sockaddr.Name)
|
||||
sdstat, err := os.Stat(sockdir)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "cannot stat(2) sockdir '%s'", sockdir)
|
||||
@ -55,9 +60,9 @@ func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener,
|
||||
return nil, errors.Errorf("sockdir must not be world-accessible (permissions are %#o)", p)
|
||||
}
|
||||
|
||||
ul, err := net.ListenUnix("unix", unixaddr) // TODO
|
||||
ul, err := net.ListenUnix("unix", f.sockaddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", unixaddr)
|
||||
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", f.sockaddr)
|
||||
}
|
||||
|
||||
l := &StdinserverListener{ul}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/spf13/cobra"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@ -35,8 +36,21 @@ type Job interface {
|
||||
}
|
||||
|
||||
func doDaemon(cmd *cobra.Command, args []string) {
|
||||
d := Daemon{}
|
||||
d.Loop()
|
||||
|
||||
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, contextKeyLog, log)
|
||||
|
||||
conf, err := ParseConfig(ctx, rootArgs.configFile)
|
||||
if err != nil {
|
||||
log.Printf("error parsing config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
d := NewDaemon(conf)
|
||||
d.Loop(ctx)
|
||||
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
@ -46,34 +60,38 @@ const (
|
||||
)
|
||||
|
||||
type Daemon struct {
|
||||
log Logger
|
||||
conf *Config
|
||||
}
|
||||
|
||||
func (d *Daemon) Loop() {
|
||||
func NewDaemon(initialConf *Config) *Daemon {
|
||||
return &Daemon{initialConf}
|
||||
}
|
||||
|
||||
func (d *Daemon) Loop(ctx context.Context) {
|
||||
|
||||
log := ctx.Value(contextKeyLog).(Logger)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
finishs := make(chan Job)
|
||||
cancels := make([]context.CancelFunc, len(conf.Jobs))
|
||||
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
log.Printf("starting jobs from config")
|
||||
i := 0
|
||||
for _, job := range conf.Jobs {
|
||||
for _, job := range d.conf.Jobs {
|
||||
log.Printf("starting job %s", job.JobName())
|
||||
|
||||
logger := jobLogger{log, job.JobName()}
|
||||
ctx := context.Background()
|
||||
ctx, cancels[i] = context.WithCancel(ctx)
|
||||
i++
|
||||
ctx = context.WithValue(ctx, contextKeyLog, logger)
|
||||
|
||||
jobCtx := context.WithValue(ctx, contextKeyLog, logger)
|
||||
go func(j Job) {
|
||||
j.JobStart(ctx)
|
||||
j.JobStart(jobCtx)
|
||||
finishs <- j
|
||||
}(job)
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
finishCount := 0
|
||||
outer:
|
||||
for {
|
||||
@ -81,7 +99,7 @@ outer:
|
||||
case j := <-finishs:
|
||||
log.Printf("job finished: %s", j.JobName())
|
||||
finishCount++
|
||||
if finishCount == len(conf.Jobs) {
|
||||
if finishCount == len(d.conf.Jobs) {
|
||||
log.Printf("all jobs finished")
|
||||
break outer
|
||||
}
|
||||
@ -89,10 +107,7 @@ outer:
|
||||
case sig := <-sigChan:
|
||||
log.Printf("received signal: %s", sig)
|
||||
log.Printf("cancelling all jobs")
|
||||
for _, c := range cancels {
|
||||
log.Printf("cancelling job")
|
||||
c()
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
|
24
cmd/main.go
24
cmd/main.go
@ -12,23 +12,14 @@ package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
golog "log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// global state / facilities
|
||||
var (
|
||||
conf *Config
|
||||
logFlags int = golog.LUTC | golog.Ldate | golog.Ltime
|
||||
log Logger
|
||||
)
|
||||
|
||||
var RootCmd = &cobra.Command{
|
||||
Use: "zrepl",
|
||||
Short: "ZFS dataset replication",
|
||||
@ -46,15 +37,13 @@ var rootArgs struct {
|
||||
}
|
||||
|
||||
func init() {
|
||||
cobra.OnInitialize(initConfig)
|
||||
//cobra.OnInitialize(initConfig)
|
||||
RootCmd.PersistentFlags().StringVar(&rootArgs.configFile, "config", "", "config file path")
|
||||
RootCmd.PersistentFlags().StringVar(&rootArgs.httpPprof, "debug.pprof.http", "", "run pprof http server on given port")
|
||||
}
|
||||
|
||||
func initConfig() {
|
||||
|
||||
log = golog.New(os.Stderr, "", logFlags)
|
||||
|
||||
// CPU profiling
|
||||
if rootArgs.httpPprof != "" {
|
||||
go func() {
|
||||
@ -62,17 +51,6 @@ func initConfig() {
|
||||
}()
|
||||
}
|
||||
|
||||
// Config
|
||||
if rootArgs.configFile == "" {
|
||||
log.Printf("config file not set")
|
||||
os.Exit(1)
|
||||
}
|
||||
var err error
|
||||
if conf, err = ParseConfig(rootArgs.configFile); err != nil {
|
||||
log.Printf("error parsing config: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
@ -4,9 +4,11 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"context"
|
||||
"github.com/ftrvxmtrx/fd"
|
||||
"github.com/spf13/cobra"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
)
|
||||
|
||||
@ -22,21 +24,27 @@ func init() {
|
||||
|
||||
func cmdStdinServer(cmd *cobra.Command, args []string) {
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Printf("stdinserver exiting with error: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||
|
||||
die := func() {
|
||||
log.Printf("stdinserver exiting after fatal error")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), contextKeyLog, log)
|
||||
conf, err := ParseConfig(ctx, rootArgs.configFile)
|
||||
if err != nil {
|
||||
log.Printf("error parsing config: %s", err)
|
||||
die()
|
||||
}
|
||||
|
||||
if len(args) != 1 || args[0] == "" {
|
||||
err = fmt.Errorf("must specify client_identity as positional argument")
|
||||
return
|
||||
die()
|
||||
}
|
||||
identity := args[0]
|
||||
|
||||
unixaddr, err := stdinserverListenerSockpath(identity)
|
||||
unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity)
|
||||
if err != nil {
|
||||
log.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
@ -46,14 +54,14 @@ func cmdStdinServer(cmd *cobra.Command, args []string) {
|
||||
conn, err := net.DialUnix("unix", nil, unixaddr)
|
||||
if err != nil {
|
||||
log.Printf("error connecting to zrepld: %s", err)
|
||||
os.Exit(1)
|
||||
die()
|
||||
}
|
||||
|
||||
log.Printf("sending stdin and stdout fds to zrepld")
|
||||
err = fd.Put(conn, os.Stdin, os.Stdout)
|
||||
if err != nil {
|
||||
log.Printf("error: %s", err)
|
||||
os.Exit(1)
|
||||
die()
|
||||
}
|
||||
|
||||
log.Printf("waiting for zrepld to close control connection")
|
||||
@ -73,11 +81,11 @@ func cmdStdinServer(cmd *cobra.Command, args []string) {
|
||||
neterr, ok := err.(net.Error)
|
||||
if !ok {
|
||||
log.Printf("received unexpected error type: %T %s", err, err)
|
||||
os.Exit(1)
|
||||
die()
|
||||
}
|
||||
if !neterr.Timeout() {
|
||||
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
|
||||
os.Exit(1)
|
||||
die()
|
||||
}
|
||||
// Read timed out, as expected
|
||||
}
|
||||
|
31
cmd/test.go
31
cmd/test.go
@ -12,6 +12,7 @@ import (
|
||||
"github.com/kr/pretty"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
"log"
|
||||
)
|
||||
|
||||
var testCmd = &cobra.Command{
|
||||
@ -19,6 +20,11 @@ var testCmd = &cobra.Command{
|
||||
Short: "test configuration",
|
||||
}
|
||||
|
||||
var testCmdGlobal struct {
|
||||
log Logger
|
||||
conf *Config
|
||||
}
|
||||
|
||||
var testConfigSyntaxCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "parse config file and dump parsed datastructure",
|
||||
@ -45,6 +51,7 @@ var testPrunePolicyCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func init() {
|
||||
cobra.OnInitialize(testCmdGlobalInit)
|
||||
RootCmd.AddCommand(testCmd)
|
||||
testCmd.AddCommand(testConfigSyntaxCmd)
|
||||
testCmd.AddCommand(testDatasetMapFilter)
|
||||
@ -55,15 +62,33 @@ func init() {
|
||||
testCmd.AddCommand(testPrunePolicyCmd)
|
||||
}
|
||||
|
||||
func testCmdGlobalInit() {
|
||||
|
||||
testCmdGlobal.log = log.New(os.Stdout, "", 0)
|
||||
|
||||
ctx := context.WithValue(context.Background(), contextKeyLog, testCmdGlobal.log)
|
||||
|
||||
var err error
|
||||
if testCmdGlobal.conf, err = ParseConfig(ctx, rootArgs.configFile); err != nil {
|
||||
testCmdGlobal.log.Printf("error parsing config file: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func doTestConfig(cmd *cobra.Command, args []string) {
|
||||
|
||||
log, conf := testCmdGlobal.log, testCmdGlobal.conf
|
||||
|
||||
log.Printf("config ok")
|
||||
|
||||
log.Printf("%# v", pretty.Formatter(conf))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func doTestDatasetMapFilter(cmd *cobra.Command, args []string) {
|
||||
|
||||
log, conf := testCmdGlobal.log, testCmdGlobal.conf
|
||||
|
||||
if len(args) != 2 {
|
||||
log.Printf("specify job name as first postitional argument, test input as second")
|
||||
log.Printf(cmd.UsageString())
|
||||
@ -120,6 +145,8 @@ func doTestDatasetMapFilter(cmd *cobra.Command, args []string) {
|
||||
|
||||
func doTestPrunePolicy(cmd *cobra.Command, args []string) {
|
||||
|
||||
log, conf := testCmdGlobal.log, testCmdGlobal.conf
|
||||
|
||||
if cmd.Flags().NArg() != 1 {
|
||||
log.Printf("specify job name as first positional argument")
|
||||
log.Printf(cmd.UsageString())
|
||||
|
Loading…
Reference in New Issue
Block a user