cmd: handler: check FilesystemVersionFilter as part of ACL

This commit is contained in:
Christian Schwarz 2017-09-16 20:24:46 +02:00
parent dc3378e890
commit e3ec093d53
3 changed files with 44 additions and 25 deletions

View File

@ -12,7 +12,7 @@ import (
type LocalJob struct { type LocalJob struct {
Name string Name string
Mapping *DatasetMapFilter Mapping *DatasetMapFilter
SnapshotFilter *PrefixSnapshotFilter SnapshotPrefix string
Interval time.Duration Interval time.Duration
InitialReplPolicy InitialReplPolicy InitialReplPolicy InitialReplPolicy
PruneLHS PrunePolicy PruneLHS PrunePolicy
@ -43,7 +43,7 @@ func parseLocalJob(name string, i map[string]interface{}) (j *LocalJob, err erro
return return
} }
if j.SnapshotFilter, err = parsePrefixSnapshotFilter(asMap.SnapshotPrefix); err != nil { if j.SnapshotPrefix, err = parseSnapshotPrefix(asMap.SnapshotPrefix); err != nil {
return return
} }
@ -82,15 +82,13 @@ func (j *LocalJob) JobStart(ctx context.Context) {
log := ctx.Value(contextKeyLog).(Logger) log := ctx.Value(contextKeyLog).(Logger)
local := rpc.NewLocalRPC() local := rpc.NewLocalRPC()
handler := Handler{
Logger: log,
// Allow access to any dataset since we control what mapping // Allow access to any dataset since we control what mapping
// is passed to the pull routine. // is passed to the pull routine.
// All local datasets will be passed to its Map() function, // All local datasets will be passed to its Map() function,
// but only those for which a mapping exists will actually be pulled. // but only those for which a mapping exists will actually be pulled.
// We can pay this small performance penalty for now. // We can pay this small performance penalty for now.
PullACL: localPullACL{}, handler := NewHandler(log, localPullACL{}, &PrefixSnapshotFilter{j.SnapshotPrefix})
}
registerEndpoints(local, handler) registerEndpoints(local, handler)
err := doPull(PullContext{local, log, j.Mapping, j.InitialReplPolicy}) err := doPull(PullContext{local, log, j.Mapping, j.InitialReplPolicy})

View File

@ -146,10 +146,7 @@ outer:
} }
// construct connection handler // construct connection handler
handler := Handler{ handler := NewHandler(log, j.Datasets, &PrefixSnapshotFilter{j.SnapshotPrefix})
Logger: log,
PullACL: j.Datasets,
}
// handle connection // handle connection
rpcServer := rpc.NewServer(rwc) rpcServer := rpc.NewServer(rwc)

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/pkg/errors"
"github.com/zrepl/zrepl/rpc" "github.com/zrepl/zrepl/rpc"
"github.com/zrepl/zrepl/zfs" "github.com/zrepl/zrepl/zfs"
) )
@ -34,6 +35,11 @@ type IncrementalTransferRequest struct {
type Handler struct { type Handler struct {
Logger Logger Logger Logger
PullACL zfs.DatasetFilter PullACL zfs.DatasetFilter
VersionFilter zfs.FilesystemVersionFilter
}
func NewHandler(logger Logger, dsfilter zfs.DatasetFilter, snapfilter zfs.FilesystemVersionFilter) (h Handler) {
return Handler{logger, dsfilter, snapfilter}
} }
func registerEndpoints(server rpc.RPCServer, handler Handler) (err error) { func registerEndpoints(server rpc.RPCServer, handler Handler) (err error) {
@ -78,12 +84,12 @@ func (h Handler) HandleFilesystemVersionsRequest(r *FilesystemVersionsRequest, v
h.Logger.Printf("handling filesystem versions request: %#v", r) h.Logger.Printf("handling filesystem versions request: %#v", r)
// allowed to request that? // allowed to request that?
if h.pullACLCheck(r.Filesystem); err != nil { if h.pullACLCheck(r.Filesystem, nil); err != nil {
return return
} }
// find our versions // find our versions
vs, err := zfs.ZFSListFilesystemVersions(r.Filesystem, nil) vs, err := zfs.ZFSListFilesystemVersions(r.Filesystem, h.VersionFilter)
if err != nil { if err != nil {
h.Logger.Printf("our versions error: %#v\n", err) h.Logger.Printf("our versions error: %#v\n", err)
return return
@ -99,7 +105,7 @@ func (h Handler) HandleFilesystemVersionsRequest(r *FilesystemVersionsRequest, v
func (h Handler) HandleInitialTransferRequest(r *InitialTransferRequest, stream *io.Reader) (err error) { func (h Handler) HandleInitialTransferRequest(r *InitialTransferRequest, stream *io.Reader) (err error) {
h.Logger.Printf("handling initial transfer request: %#v", r) h.Logger.Printf("handling initial transfer request: %#v", r)
if err = h.pullACLCheck(r.Filesystem); err != nil { if err = h.pullACLCheck(r.Filesystem, &r.FilesystemVersion); err != nil {
return return
} }
@ -118,7 +124,10 @@ func (h Handler) HandleInitialTransferRequest(r *InitialTransferRequest, stream
func (h Handler) HandleIncrementalTransferRequest(r *IncrementalTransferRequest, stream *io.Reader) (err error) { func (h Handler) HandleIncrementalTransferRequest(r *IncrementalTransferRequest, stream *io.Reader) (err error) {
h.Logger.Printf("handling incremental transfer request: %#v", r) h.Logger.Printf("handling incremental transfer request: %#v", r)
if err = h.pullACLCheck(r.Filesystem); err != nil { if err = h.pullACLCheck(r.Filesystem, &r.From); err != nil {
return
}
if err = h.pullACLCheck(r.Filesystem, &r.To); err != nil {
return return
} }
@ -134,18 +143,33 @@ func (h Handler) HandleIncrementalTransferRequest(r *IncrementalTransferRequest,
} }
func (h Handler) pullACLCheck(p *zfs.DatasetPath) (err error) { func (h Handler) pullACLCheck(p *zfs.DatasetPath, v *zfs.FilesystemVersion) (err error) {
var allowed bool var fsAllowed, vAllowed bool
allowed, err = h.PullACL.Filter(p) fsAllowed, err = h.PullACL.Filter(p)
if err != nil { if err != nil {
err = fmt.Errorf("error evaluating ACL: %s", err) err = fmt.Errorf("error evaluating ACL: %s", err)
h.Logger.Printf(err.Error()) h.Logger.Printf(err.Error())
return return
} }
if !allowed { if !fsAllowed {
err = fmt.Errorf("ACL prohibits access to %s", p.ToString()) err = fmt.Errorf("ACL prohibits access to %s", p.ToString())
h.Logger.Printf(err.Error()) h.Logger.Printf(err.Error())
return return
} }
if v == nil {
return
}
vAllowed, err = h.VersionFilter.Filter(*v)
if err != nil {
err = errors.Wrap(err, "error evaluating version filter")
h.Logger.Printf(err.Error())
return
}
if !vAllowed {
err = fmt.Errorf("ACL prohibits access to %s", v.ToAbsPath(p))
h.Logger.Printf(err.Error())
return
}
return return
} }