cmd: remove global state in main.go

* refactoring
* Now supporting default config locations
This commit is contained in:
Christian Schwarz 2017-09-17 18:20:05 +02:00
parent 4ac7e78e2b
commit 9cd83399d3
9 changed files with 148 additions and 77 deletions

View File

@ -23,7 +23,7 @@ type LocalJob struct {
Debug JobDebugSettings Debug JobDebugSettings
} }
func parseLocalJob(name string, i map[string]interface{}) (j *LocalJob, err error) { func parseLocalJob(c JobParsingContext, name string, i map[string]interface{}) (j *LocalJob, err error) {
var asMap struct { var asMap struct {
Mapping map[string]string Mapping map[string]string

View File

@ -23,7 +23,7 @@ type PullJob struct {
Debug JobDebugSettings Debug JobDebugSettings
} }
func parsePullJob(name string, i map[string]interface{}) (j *PullJob, err error) { func parsePullJob(c JobParsingContext, name string, i map[string]interface{}) (j *PullJob, err error) {
var asMap struct { var asMap struct {
Connect map[string]interface{} Connect map[string]interface{}

View File

@ -20,7 +20,7 @@ type SourceJob struct {
Debug JobDebugSettings Debug JobDebugSettings
} }
func parseSourceJob(name string, i map[string]interface{}) (j *SourceJob, err error) { func parseSourceJob(c JobParsingContext, name string, i map[string]interface{}) (j *SourceJob, err error) {
var asMap struct { var asMap struct {
Serve map[string]interface{} Serve map[string]interface{}
@ -38,7 +38,7 @@ func parseSourceJob(name string, i map[string]interface{}) (j *SourceJob, err er
j = &SourceJob{Name: name} j = &SourceJob{Name: name}
if j.Serve, err = parseAuthenticatedChannelListenerFactory(asMap.Serve); err != nil { if j.Serve, err = parseAuthenticatedChannelListenerFactory(c, asMap.Serve); err != nil {
return return
} }

View File

@ -3,13 +3,44 @@ package cmd
import ( import (
"io/ioutil" "io/ioutil"
"context"
"fmt" "fmt"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
"os"
) )
func ParseConfig(path string) (config *Config, err error) { var ConfigFileDefaultLocations []string = []string{
"/etc/zrepl/zrepl.yml",
"/usr/local/etc/zrepl/zrepl.yml",
}
type ConfigParsingContext struct {
Global *Global
}
func ParseConfig(ctx context.Context, path string) (config *Config, err error) {
log := ctx.Value(contextKeyLog).(Logger)
if path == "" {
// Try default locations
for _, l := range ConfigFileDefaultLocations {
log.Printf("trying config location %s", l)
stat, err := os.Stat(l)
if err != nil {
log.Printf("stat error: %s", err)
continue
}
if !stat.Mode().IsRegular() {
log.Printf("warning: file at default location is not a regular file: %s", l)
continue
}
path = l
break
}
}
var i interface{} var i interface{}
@ -48,10 +79,13 @@ func parseConfig(i interface{}) (c *Config, err error) {
return return
} }
cpc := ConfigParsingContext{&c.Global}
jpc := JobParsingContext{cpc}
// Parse Jobs // Parse Jobs
c.Jobs = make(map[string]Job, len(asMap.Jobs)) c.Jobs = make(map[string]Job, len(asMap.Jobs))
for i := range asMap.Jobs { for i := range asMap.Jobs {
job, err := parseJob(asMap.Jobs[i]) job, err := parseJob(jpc, asMap.Jobs[i])
if err != nil { if err != nil {
// Try to find its name // Try to find its name
namei, ok := asMap.Jobs[i]["name"] namei, ok := asMap.Jobs[i]["name"]
@ -86,7 +120,11 @@ func extractStringField(i map[string]interface{}, key string, notempty bool) (fi
return return
} }
func parseJob(i map[string]interface{}) (j Job, err error) { type JobParsingContext struct {
ConfigParsingContext
}
func parseJob(c JobParsingContext, i map[string]interface{}) (j Job, err error) {
name, err := extractStringField(i, "name", true) name, err := extractStringField(i, "name", true)
if err != nil { if err != nil {
@ -101,11 +139,11 @@ func parseJob(i map[string]interface{}) (j Job, err error) {
switch jobtype { switch jobtype {
case "pull": case "pull":
return parsePullJob(name, i) return parsePullJob(c, name, i)
case "source": case "source":
return parseSourceJob(name, i) return parseSourceJob(c, name, i)
case "local": case "local":
return parseLocalJob(name, i) return parseLocalJob(c, name, i)
default: default:
return nil, errors.Errorf("unknown job type '%s'", jobtype) return nil, errors.Errorf("unknown job type '%s'", jobtype)
} }
@ -179,7 +217,7 @@ func parsePrunePolicy(v map[string]interface{}) (p PrunePolicy, err error) {
} }
func parseAuthenticatedChannelListenerFactory(v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) { func parseAuthenticatedChannelListenerFactory(c JobParsingContext, v map[string]interface{}) (p AuthenticatedChannelListenerFactory, err error) {
t, err := extractStringField(v, "type", true) t, err := extractStringField(v, "type", true)
if err != nil { if err != nil {
@ -188,7 +226,7 @@ func parseAuthenticatedChannelListenerFactory(v map[string]interface{}) (p Authe
switch t { switch t {
case "stdinserver": case "stdinserver":
return parseStdinserverListenerFactory(v) return parseStdinserverListenerFactory(c, v)
default: default:
err = errors.Errorf("unknown type '%s'", t) err = errors.Errorf("unknown type '%s'", t)
return return

View File

@ -13,9 +13,10 @@ import (
type StdinserverListenerFactory struct { type StdinserverListenerFactory struct {
ClientIdentity string `mapstructure:"client_identity"` ClientIdentity string `mapstructure:"client_identity"`
sockaddr *net.UnixAddr
} }
func parseStdinserverListenerFactory(i map[string]interface{}) (f *StdinserverListenerFactory, err error) { func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
f = &StdinserverListenerFactory{} f = &StdinserverListenerFactory{}
@ -26,11 +27,17 @@ func parseStdinserverListenerFactory(i map[string]interface{}) (f *StdinserverLi
err = errors.Errorf("must specify 'client_identity'") err = errors.Errorf("must specify 'client_identity'")
return return
} }
f.sockaddr, err = stdinserverListenerSocket(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
if err != nil {
return
}
return return
} }
func stdinserverListenerSockpath(clientIdentity string) (addr *net.UnixAddr, err error) { func stdinserverListenerSocket(sockdir, clientIdentity string) (addr *net.UnixAddr, err error) {
sockpath := path.Join(conf.Global.Serve.Stdinserver.SockDir, clientIdentity) sockpath := path.Join(sockdir, clientIdentity)
addr, err = net.ResolveUnixAddr("unix", sockpath) addr, err = net.ResolveUnixAddr("unix", sockpath)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot resolve unix address") return nil, errors.Wrap(err, "cannot resolve unix address")
@ -40,9 +47,7 @@ func stdinserverListenerSockpath(clientIdentity string) (addr *net.UnixAddr, err
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) { func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
unixaddr, err := stdinserverListenerSockpath(f.ClientIdentity) sockdir := filepath.Dir(f.sockaddr.Name)
sockdir := filepath.Dir(unixaddr.Name)
sdstat, err := os.Stat(sockdir) sdstat, err := os.Stat(sockdir)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "cannot stat(2) sockdir '%s'", sockdir) return nil, errors.Wrapf(err, "cannot stat(2) sockdir '%s'", sockdir)
@ -55,9 +60,9 @@ func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener,
return nil, errors.Errorf("sockdir must not be world-accessible (permissions are %#o)", p) return nil, errors.Errorf("sockdir must not be world-accessible (permissions are %#o)", p)
} }
ul, err := net.ListenUnix("unix", unixaddr) // TODO ul, err := net.ListenUnix("unix", f.sockaddr)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", unixaddr) return nil, errors.Wrapf(err, "cannot listen on unix socket %s", f.sockaddr)
} }
l := &StdinserverListener{ul} l := &StdinserverListener{ul}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"log"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@ -35,8 +36,21 @@ type Job interface {
} }
func doDaemon(cmd *cobra.Command, args []string) { func doDaemon(cmd *cobra.Command, args []string) {
d := Daemon{}
d.Loop() log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
ctx := context.Background()
ctx = context.WithValue(ctx, contextKeyLog, log)
conf, err := ParseConfig(ctx, rootArgs.configFile)
if err != nil {
log.Printf("error parsing config: %s", err)
os.Exit(1)
}
d := NewDaemon(conf)
d.Loop(ctx)
} }
type contextKey string type contextKey string
@ -46,34 +60,38 @@ const (
) )
type Daemon struct { type Daemon struct {
log Logger conf *Config
} }
func (d *Daemon) Loop() { func NewDaemon(initialConf *Config) *Daemon {
return &Daemon{initialConf}
}
func (d *Daemon) Loop(ctx context.Context) {
log := ctx.Value(contextKeyLog).(Logger)
ctx, cancel := context.WithCancel(ctx)
sigChan := make(chan os.Signal, 1)
finishs := make(chan Job) finishs := make(chan Job)
cancels := make([]context.CancelFunc, len(conf.Jobs))
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
log.Printf("starting jobs from config") log.Printf("starting jobs from config")
i := 0 i := 0
for _, job := range conf.Jobs { for _, job := range d.conf.Jobs {
log.Printf("starting job %s", job.JobName()) log.Printf("starting job %s", job.JobName())
logger := jobLogger{log, job.JobName()} logger := jobLogger{log, job.JobName()}
ctx := context.Background()
ctx, cancels[i] = context.WithCancel(ctx)
i++ i++
ctx = context.WithValue(ctx, contextKeyLog, logger) jobCtx := context.WithValue(ctx, contextKeyLog, logger)
go func(j Job) { go func(j Job) {
j.JobStart(ctx) j.JobStart(jobCtx)
finishs <- j finishs <- j
}(job) }(job)
} }
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
finishCount := 0 finishCount := 0
outer: outer:
for { for {
@ -81,7 +99,7 @@ outer:
case j := <-finishs: case j := <-finishs:
log.Printf("job finished: %s", j.JobName()) log.Printf("job finished: %s", j.JobName())
finishCount++ finishCount++
if finishCount == len(conf.Jobs) { if finishCount == len(d.conf.Jobs) {
log.Printf("all jobs finished") log.Printf("all jobs finished")
break outer break outer
} }
@ -89,10 +107,7 @@ outer:
case sig := <-sigChan: case sig := <-sigChan:
log.Printf("received signal: %s", sig) log.Printf("received signal: %s", sig)
log.Printf("cancelling all jobs") log.Printf("cancelling all jobs")
for _, c := range cancels { cancel()
log.Printf("cancelling job")
c()
}
} }
} }

View File

@ -12,23 +12,14 @@ package cmd
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
golog "log"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os"
) )
type Logger interface { type Logger interface {
Printf(format string, v ...interface{}) Printf(format string, v ...interface{})
} }
// global state / facilities
var (
conf *Config
logFlags int = golog.LUTC | golog.Ldate | golog.Ltime
log Logger
)
var RootCmd = &cobra.Command{ var RootCmd = &cobra.Command{
Use: "zrepl", Use: "zrepl",
Short: "ZFS dataset replication", Short: "ZFS dataset replication",
@ -46,15 +37,13 @@ var rootArgs struct {
} }
func init() { func init() {
cobra.OnInitialize(initConfig) //cobra.OnInitialize(initConfig)
RootCmd.PersistentFlags().StringVar(&rootArgs.configFile, "config", "", "config file path") RootCmd.PersistentFlags().StringVar(&rootArgs.configFile, "config", "", "config file path")
RootCmd.PersistentFlags().StringVar(&rootArgs.httpPprof, "debug.pprof.http", "", "run pprof http server on given port") RootCmd.PersistentFlags().StringVar(&rootArgs.httpPprof, "debug.pprof.http", "", "run pprof http server on given port")
} }
func initConfig() { func initConfig() {
log = golog.New(os.Stderr, "", logFlags)
// CPU profiling // CPU profiling
if rootArgs.httpPprof != "" { if rootArgs.httpPprof != "" {
go func() { go func() {
@ -62,17 +51,6 @@ func initConfig() {
}() }()
} }
// Config
if rootArgs.configFile == "" {
log.Printf("config file not set")
os.Exit(1)
}
var err error
if conf, err = ParseConfig(rootArgs.configFile); err != nil {
log.Printf("error parsing config: %s", err)
os.Exit(1)
}
return return
} }

View File

@ -4,9 +4,11 @@ import (
"fmt" "fmt"
"os" "os"
"context"
"github.com/ftrvxmtrx/fd" "github.com/ftrvxmtrx/fd"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"io" "io"
"log"
"net" "net"
) )
@ -22,21 +24,27 @@ func init() {
func cmdStdinServer(cmd *cobra.Command, args []string) { func cmdStdinServer(cmd *cobra.Command, args []string) {
var err error log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
defer func() {
if err != nil { die := func() {
log.Printf("stdinserver exiting with error: %s", err) log.Printf("stdinserver exiting after fatal error")
os.Exit(1) os.Exit(1)
} }
}()
ctx := context.WithValue(context.Background(), contextKeyLog, log)
conf, err := ParseConfig(ctx, rootArgs.configFile)
if err != nil {
log.Printf("error parsing config: %s", err)
die()
}
if len(args) != 1 || args[0] == "" { if len(args) != 1 || args[0] == "" {
err = fmt.Errorf("must specify client_identity as positional argument") err = fmt.Errorf("must specify client_identity as positional argument")
return die()
} }
identity := args[0] identity := args[0]
unixaddr, err := stdinserverListenerSockpath(identity) unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity)
if err != nil { if err != nil {
log.Printf("%s", err) log.Printf("%s", err)
os.Exit(1) os.Exit(1)
@ -46,14 +54,14 @@ func cmdStdinServer(cmd *cobra.Command, args []string) {
conn, err := net.DialUnix("unix", nil, unixaddr) conn, err := net.DialUnix("unix", nil, unixaddr)
if err != nil { if err != nil {
log.Printf("error connecting to zrepld: %s", err) log.Printf("error connecting to zrepld: %s", err)
os.Exit(1) die()
} }
log.Printf("sending stdin and stdout fds to zrepld") log.Printf("sending stdin and stdout fds to zrepld")
err = fd.Put(conn, os.Stdin, os.Stdout) err = fd.Put(conn, os.Stdin, os.Stdout)
if err != nil { if err != nil {
log.Printf("error: %s", err) log.Printf("error: %s", err)
os.Exit(1) die()
} }
log.Printf("waiting for zrepld to close control connection") log.Printf("waiting for zrepld to close control connection")
@ -73,11 +81,11 @@ func cmdStdinServer(cmd *cobra.Command, args []string) {
neterr, ok := err.(net.Error) neterr, ok := err.(net.Error)
if !ok { if !ok {
log.Printf("received unexpected error type: %T %s", err, err) log.Printf("received unexpected error type: %T %s", err, err)
os.Exit(1) die()
} }
if !neterr.Timeout() { if !neterr.Timeout() {
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr) log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
os.Exit(1) die()
} }
// Read timed out, as expected // Read timed out, as expected
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/kr/pretty" "github.com/kr/pretty"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/zrepl/zrepl/zfs" "github.com/zrepl/zrepl/zfs"
"log"
) )
var testCmd = &cobra.Command{ var testCmd = &cobra.Command{
@ -19,6 +20,11 @@ var testCmd = &cobra.Command{
Short: "test configuration", Short: "test configuration",
} }
var testCmdGlobal struct {
log Logger
conf *Config
}
var testConfigSyntaxCmd = &cobra.Command{ var testConfigSyntaxCmd = &cobra.Command{
Use: "config", Use: "config",
Short: "parse config file and dump parsed datastructure", Short: "parse config file and dump parsed datastructure",
@ -45,6 +51,7 @@ var testPrunePolicyCmd = &cobra.Command{
} }
func init() { func init() {
cobra.OnInitialize(testCmdGlobalInit)
RootCmd.AddCommand(testCmd) RootCmd.AddCommand(testCmd)
testCmd.AddCommand(testConfigSyntaxCmd) testCmd.AddCommand(testConfigSyntaxCmd)
testCmd.AddCommand(testDatasetMapFilter) testCmd.AddCommand(testDatasetMapFilter)
@ -55,15 +62,33 @@ func init() {
testCmd.AddCommand(testPrunePolicyCmd) testCmd.AddCommand(testPrunePolicyCmd)
} }
func testCmdGlobalInit() {
testCmdGlobal.log = log.New(os.Stdout, "", 0)
ctx := context.WithValue(context.Background(), contextKeyLog, testCmdGlobal.log)
var err error
if testCmdGlobal.conf, err = ParseConfig(ctx, rootArgs.configFile); err != nil {
testCmdGlobal.log.Printf("error parsing config file: %s", err)
os.Exit(1)
}
}
func doTestConfig(cmd *cobra.Command, args []string) { func doTestConfig(cmd *cobra.Command, args []string) {
log, conf := testCmdGlobal.log, testCmdGlobal.conf
log.Printf("config ok") log.Printf("config ok")
log.Printf("%# v", pretty.Formatter(conf)) log.Printf("%# v", pretty.Formatter(conf))
return return
} }
func doTestDatasetMapFilter(cmd *cobra.Command, args []string) { func doTestDatasetMapFilter(cmd *cobra.Command, args []string) {
log, conf := testCmdGlobal.log, testCmdGlobal.conf
if len(args) != 2 { if len(args) != 2 {
log.Printf("specify job name as first postitional argument, test input as second") log.Printf("specify job name as first postitional argument, test input as second")
log.Printf(cmd.UsageString()) log.Printf(cmd.UsageString())
@ -120,6 +145,8 @@ func doTestDatasetMapFilter(cmd *cobra.Command, args []string) {
func doTestPrunePolicy(cmd *cobra.Command, args []string) { func doTestPrunePolicy(cmd *cobra.Command, args []string) {
log, conf := testCmdGlobal.log, testCmdGlobal.conf
if cmd.Flags().NArg() != 1 { if cmd.Flags().NArg() != 1 {
log.Printf("specify job name as first positional argument") log.Printf("specify job name as first positional argument")
log.Printf(cmd.UsageString()) log.Printf(cmd.UsageString())