job pull: refactor + use Task API

refs #10
This commit is contained in:
Christian Schwarz 2017-12-26 22:05:20 +01:00
parent b69089a527
commit 7d89d1fb00
2 changed files with 47 additions and 49 deletions

View File

@ -4,6 +4,7 @@ import (
"time" "time"
"context" "context"
"fmt"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/zrepl/zrepl/rpc" "github.com/zrepl/zrepl/rpc"
@ -98,16 +99,30 @@ func (j *PullJob) JobStart(ctx context.Context) {
log := ctx.Value(contextKeyLog).(Logger) log := ctx.Value(contextKeyLog).(Logger)
defer log.Info("exiting") defer log.Info("exiting")
j.task = NewTask("main", log) j.task = NewTask("main", log)
log = j.task.Log()
// j.task is idle here idle here
ticker := time.NewTicker(j.Interval) ticker := time.NewTicker(j.Interval)
for {
j.doRun(ctx)
select {
case <-ctx.Done():
j.task.Log().WithError(ctx.Err()).Info("context")
return
case <-ticker.C:
}
}
}
start: func (j *PullJob) doRun(ctx context.Context) {
log.Info("connecting") j.task.Enter("run")
defer j.task.Finish()
j.task.Log().Info("connecting")
rwc, err := j.Connect.Connect() rwc, err := j.Connect.Connect()
if err != nil { if err != nil {
log.WithError(err).Error("error connecting") j.task.Log().WithError(err).Error("error connecting")
return return
} }
@ -118,37 +133,24 @@ start:
client := rpc.NewClient(rwc) client := rpc.NewClient(rwc)
if j.Debug.RPC.Log { if j.Debug.RPC.Log {
client.SetLogger(log, true) client.SetLogger(j.task.Log(), true)
} }
log.Info("starting pull")
j.task.Enter("pull") j.task.Enter("pull")
puller := Puller{j.task, client, j.Mapping, j.InitialReplPolicy} puller := Puller{j.task, client, j.Mapping, j.InitialReplPolicy}
puller.Pull() puller.Pull()
j.task.Finish() j.task.Finish()
closeRPCWithTimeout(log, client, time.Second*10, "") closeRPCWithTimeout(j.task, client, time.Second*10, "")
log.Info("starting prune") j.task.Enter("prune")
prunectx := context.WithValue(ctx, contextKeyLog, log.WithField(logTaskField, "prune"))
pruner, err := j.Pruner(j.task, PrunePolicySideDefault, false) pruner, err := j.Pruner(j.task, PrunePolicySideDefault, false)
if err != nil { if err != nil {
log.WithError(err).Error("error creating pruner") j.task.Log().WithError(err).Error("error creating pruner")
return } else {
} pruner.Run(ctx)
pruner.Run(prunectx)
log.Info("finish prune")
log.Info("wait for next interval")
select {
case <-ctx.Done():
log.WithError(ctx.Err()).Info("context")
return
case <-ticker.C:
goto start
} }
j.task.Finish()
} }
@ -168,6 +170,26 @@ func (j *PullJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Prune
return return
} }
func (j *PullJob) doRun(ctx context.Context) { func closeRPCWithTimeout(task *Task, remote rpc.RPCClient, timeout time.Duration, goodbye string) {
task.Log().Info("closing rpc connection")
ch := make(chan error)
go func() {
ch <- remote.Close()
close(ch)
}()
var err error
select {
case <-time.After(timeout):
err = fmt.Errorf("timeout exceeded (%s)", timeout)
case closeRequestErr := <-ch:
err = closeRequestErr
}
if err != nil {
task.Log().WithError(err).Error("error closing connection")
}
return
} }

View File

@ -3,7 +3,6 @@ package cmd
import ( import (
"fmt" "fmt"
"io" "io"
"time"
"bytes" "bytes"
"encoding/json" "encoding/json"
@ -26,29 +25,6 @@ const (
InitialReplPolicyAll InitialReplPolicy = "all" InitialReplPolicyAll InitialReplPolicy = "all"
) )
func closeRPCWithTimeout(log Logger, remote rpc.RPCClient, timeout time.Duration, goodbye string) {
log.Info("closing rpc connection")
ch := make(chan error)
go func() {
ch <- remote.Close()
close(ch)
}()
var err error
select {
case <-time.After(timeout):
err = fmt.Errorf("timeout exceeded (%s)", timeout)
case closeRequestErr := <-ch:
err = closeRequestErr
}
if err != nil {
log.WithError(err).Error("error closing connection")
}
return
}
type Puller struct { type Puller struct {
task *Task task *Task
Remote rpc.RPCClient Remote rpc.RPCClient