extract PullACL check into function

This commit is contained in:
Christian Schwarz 2017-08-05 19:40:11 +02:00
parent 4732fdd4cc
commit 3fac6a67df

View File

@ -31,14 +31,12 @@ func (h Handler) HandleFilesystemRequest(r rpc.FilesystemRequest) (roots []zfs.D
func (h Handler) HandleFilesystemVersionsRequest(r rpc.FilesystemVersionsRequest) (versions []zfs.FilesystemVersion, err error) {
h.Logger.Printf("handling filesystem versions request: %#v", r)
// allowed to request that?
if _, err = h.PullACL.Map(r.Filesystem); err != nil {
h.Logger.Printf("filesystem: %#v\n", r.Filesystem)
h.Logger.Printf("pull mapping: %#v\n", h.PullACL)
h.Logger.Printf("allowed error: %#v\n", err)
if h.pullACLCheck(r.Filesystem); err != nil {
return
}
h.Logger.Printf("allowed: %#v\n", r.Filesystem)
// find our versions
if versions, err = zfs.ZFSListFilesystemVersions(r.Filesystem, nil); err != nil {
@ -54,9 +52,7 @@ func (h Handler) HandleFilesystemVersionsRequest(r rpc.FilesystemVersionsRequest
func (h Handler) HandleInitialTransferRequest(r rpc.InitialTransferRequest) (stream io.Reader, err error) {
h.Logger.Printf("handling initial transfer request: %#v", r)
// allowed to request that?
if _, err = h.PullACL.Map(r.Filesystem); err != nil {
h.Logger.Printf("initial transfer request acl errror: %#v", err)
if err = h.pullACLCheck(r.Filesystem); err != nil {
return
}
@ -73,9 +69,7 @@ func (h Handler) HandleInitialTransferRequest(r rpc.InitialTransferRequest) (str
func (h Handler) HandleIncrementalTransferRequest(r rpc.IncrementalTransferRequest) (stream io.Reader, err error) {
h.Logger.Printf("handling incremental transfer request: %#v", r)
// allowed to request that?
if _, err = h.PullACL.Map(r.Filesystem); err != nil {
h.Logger.Printf("incremental transfer request acl errror: %#v", err)
if err = h.pullACLCheck(r.Filesystem); err != nil {
return
}
@ -121,3 +115,19 @@ func (h Handler) HandlePullMeRequest(r rpc.PullMeRequest, clientIdentity string,
return
}
func (h Handler) pullACLCheck(p zfs.DatasetPath) (err error) {
var allowed bool
allowed, err = h.PullACL.Filter(p)
if err != nil {
err = fmt.Errorf("error evaluating ACL: %s", err)
h.Logger.Printf(err.Error())
return
}
if !allowed {
err = fmt.Errorf("ACL prohibits access to %s", p.ToString())
h.Logger.Printf(err.Error())
return
}
return
}