restic: refactor to use lib/http

Co-authored-by: Nick Craig-Wood <nick@craig-wood.com>
This commit is contained in:
Nolan Woods 2021-05-02 00:56:24 -07:00 committed by Nick Craig-Wood
parent 4444d2d102
commit 52443c2444
7 changed files with 270 additions and 181 deletions

View File

@ -9,20 +9,22 @@ import (
// cache implements a simple object cache // cache implements a simple object cache
type cache struct { type cache struct {
mu sync.RWMutex // protects the cache mu sync.RWMutex // protects the cache
items map[string]fs.Object // cache of objects items map[string]fs.Object // cache of objects
cacheObjects bool // whether we are actually caching
} }
// create a new cache // create a new cache
func newCache() *cache { func newCache(cacheObjects bool) *cache {
return &cache{ return &cache{
items: map[string]fs.Object{}, items: map[string]fs.Object{},
cacheObjects: cacheObjects,
} }
} }
// find the object at remote or return nil // find the object at remote or return nil
func (c *cache) find(remote string) fs.Object { func (c *cache) find(remote string) fs.Object {
if !cacheObjects { if !c.cacheObjects {
return nil return nil
} }
c.mu.RLock() c.mu.RLock()
@ -33,7 +35,7 @@ func (c *cache) find(remote string) fs.Object {
// add the object to the cache // add the object to the cache
func (c *cache) add(remote string, o fs.Object) { func (c *cache) add(remote string, o fs.Object) {
if !cacheObjects { if !c.cacheObjects {
return return
} }
c.mu.Lock() c.mu.Lock()
@ -43,7 +45,7 @@ func (c *cache) add(remote string, o fs.Object) {
// remove the object from the cache // remove the object from the cache
func (c *cache) remove(remote string) { func (c *cache) remove(remote string) {
if !cacheObjects { if !c.cacheObjects {
return return
} }
c.mu.Lock() c.mu.Lock()
@ -53,7 +55,7 @@ func (c *cache) remove(remote string) {
// remove all the items with prefix from the cache // remove all the items with prefix from the cache
func (c *cache) removePrefix(prefix string) { func (c *cache) removePrefix(prefix string) {
if !cacheObjects { if !c.cacheObjects {
return return
} }

View File

@ -21,7 +21,7 @@ func (c *cache) String() string {
} }
func TestCacheCRUD(t *testing.T) { func TestCacheCRUD(t *testing.T) {
c := newCache() c := newCache(true)
assert.Equal(t, "", c.String()) assert.Equal(t, "", c.String())
assert.Nil(t, c.find("potato")) assert.Nil(t, c.find("potato"))
o := mockobject.New("potato") o := mockobject.New("potato")
@ -35,7 +35,7 @@ func TestCacheCRUD(t *testing.T) {
} }
func TestCacheRemovePrefix(t *testing.T) { func TestCacheRemovePrefix(t *testing.T) {
c := newCache() c := newCache(true)
for _, remote := range []string{ for _, remote := range []string{
"a", "a",
"b", "b",

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"net/http" "net/http"
"os" "os"
"path" "path"
@ -12,34 +13,48 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/rclone/rclone/cmd" "github.com/rclone/rclone/cmd"
"github.com/rclone/rclone/cmd/serve/httplib"
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
"github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/accounting" "github.com/rclone/rclone/fs/accounting"
"github.com/rclone/rclone/fs/config/flags" "github.com/rclone/rclone/fs/config/flags"
"github.com/rclone/rclone/fs/operations" "github.com/rclone/rclone/fs/operations"
"github.com/rclone/rclone/fs/walk" "github.com/rclone/rclone/fs/walk"
libhttp "github.com/rclone/rclone/lib/http"
"github.com/rclone/rclone/lib/http/serve" "github.com/rclone/rclone/lib/http/serve"
"github.com/rclone/rclone/lib/terminal" "github.com/rclone/rclone/lib/terminal"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
var ( // Options required for http server
stdio bool type Options struct {
appendOnly bool Auth libhttp.AuthConfig
privateRepos bool HTTP libhttp.Config
cacheObjects bool Stdio bool
) AppendOnly bool
PrivateRepos bool
CacheObjects bool
}
// DefaultOpt is the default values used for Options
var DefaultOpt = Options{
Auth: libhttp.DefaultAuthCfg(),
HTTP: libhttp.DefaultCfg(),
}
// Opt is options set by command line flags
var Opt = DefaultOpt
func init() { func init() {
httpflags.AddFlags(Command.Flags())
flagSet := Command.Flags() flagSet := Command.Flags()
flags.BoolVarP(flagSet, &stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout") libhttp.AddAuthFlagsPrefix(flagSet, "", &Opt.Auth)
flags.BoolVarP(flagSet, &appendOnly, "append-only", "", false, "Disallow deletion of repository data") libhttp.AddHTTPFlagsPrefix(flagSet, "", &Opt.HTTP)
flags.BoolVarP(flagSet, &privateRepos, "private-repos", "", false, "Users can only access their private repo") flags.BoolVarP(flagSet, &Opt.Stdio, "stdio", "", false, "Run an HTTP2 server on stdin/stdout")
flags.BoolVarP(flagSet, &cacheObjects, "cache-objects", "", true, "Cache listed objects") flags.BoolVarP(flagSet, &Opt.AppendOnly, "append-only", "", false, "Disallow deletion of repository data")
flags.BoolVarP(flagSet, &Opt.PrivateRepos, "private-repos", "", false, "Users can only access their private repo")
flags.BoolVarP(flagSet, &Opt.CacheObjects, "cache-objects", "", true, "Cache listed objects")
} }
// Command definition for cobra // Command definition for cobra
@ -127,16 +142,21 @@ these **must** end with /. Eg
The` + "`--private-repos`" + ` flag can be used to limit users to repositories starting The` + "`--private-repos`" + ` flag can be used to limit users to repositories starting
with a path of ` + "`/<username>/`" + `. with a path of ` + "`/<username>/`" + `.
` + httplib.Help, ` + libhttp.Help + libhttp.AuthHelp,
Annotations: map[string]string{ Annotations: map[string]string{
"versionIntroduced": "v1.40", "versionIntroduced": "v1.40",
}, },
Run: func(command *cobra.Command, args []string) { Run: func(command *cobra.Command, args []string) {
ctx := context.Background()
cmd.CheckArgs(1, 1, command, args) cmd.CheckArgs(1, 1, command, args)
f := cmd.NewFsSrc(args) f := cmd.NewFsSrc(args)
cmd.Run(false, true, command, func() error { cmd.Run(false, true, command, func() error {
s := NewServer(f, &httpflags.Opt) s, err := newServer(ctx, f, &Opt)
if stdio { if err != nil {
return err
}
fs.Logf(s.f, "Serving restic REST API on %s", s.URLs())
if s.opt.Stdio {
if terminal.IsTerminal(int(os.Stdout.Fd())) { if terminal.IsTerminal(int(os.Stdout.Fd())) {
return errors.New("refusing to run HTTP2 server directly on a terminal, please let restic start rclone") return errors.New("refusing to run HTTP2 server directly on a terminal, please let restic start rclone")
} }
@ -148,16 +168,11 @@ with a path of ` + "`/<username>/`" + `.
httpSrv := &http2.Server{} httpSrv := &http2.Server{}
opts := &http2.ServeConnOpts{ opts := &http2.ServeConnOpts{
Handler: s, Handler: s.Server.Router(),
} }
httpSrv.ServeConn(conn, opts) httpSrv.ServeConn(conn, opts)
return nil return nil
} }
err := s.Serve()
if err != nil {
return err
}
s.Wait()
return nil return nil
}) })
}, },
@ -167,101 +182,130 @@ const (
resticAPIV2 = "application/vnd.x.restic.rest.v2" resticAPIV2 = "application/vnd.x.restic.rest.v2"
) )
// Server contains everything to run the Server type contextRemoteType struct{}
type Server struct {
*httplib.Server // ContextRemoteKey is a simple context key for storing the username of the request
var ContextRemoteKey = &contextRemoteType{}
// WithRemote makes a remote from a URL path. This implements the backend layout
// required by restic.
func WithRemote(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var urlpath string
rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
urlpath = rctx.RoutePath
} else {
urlpath = r.URL.Path
}
urlpath = strings.Trim(urlpath, "/")
parts := matchData.FindStringSubmatch(urlpath)
// if no data directory, layout is flat
if parts != nil {
// otherwise map
// data/2159dd48 to
// data/21/2159dd48
fileName := parts[1]
prefix := urlpath[:len(urlpath)-len(fileName)]
urlpath = prefix + fileName[:2] + "/" + fileName
}
ctx := context.WithValue(r.Context(), ContextRemoteKey, urlpath)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// Middleware to ensure authenticated user is accessing their own private folder
func checkPrivate(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := chi.URLParam(r, "userID")
userID, ok := libhttp.CtxGetUser(r.Context())
if ok && user != "" && user == userID {
next.ServeHTTP(w, r)
} else {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}
})
}
// server contains everything to run the server
type server struct {
*libhttp.Server
f fs.Fs f fs.Fs
cache *cache cache *cache
opt Options
} }
// NewServer returns an HTTP server that speaks the rest protocol func newServer(ctx context.Context, f fs.Fs, opt *Options) (s *server, err error) {
func NewServer(f fs.Fs, opt *httplib.Options) *Server { s = &server{
mux := http.NewServeMux() f: f,
s := &Server{ cache: newCache(opt.CacheObjects),
Server: httplib.NewServer(mux, opt), opt: *opt,
f: f,
cache: newCache(),
} }
mux.HandleFunc(s.Opt.BaseURL+"/", s.ServeHTTP) s.Server, err = libhttp.NewServer(ctx,
return s libhttp.WithConfig(opt.HTTP),
} libhttp.WithAuth(opt.Auth),
)
// Serve runs the http server in the background.
//
// Use s.Close() and s.Wait() to shutdown server
func (s *Server) Serve() error {
err := s.Server.Serve()
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to init server: %w", err)
}
router := s.Router()
s.Bind(router)
s.Server.Serve()
return s, nil
}
// bind helper for main Bind method
func (s *server) bind(router chi.Router) {
router.MethodFunc("GET", "/*", func(w http.ResponseWriter, r *http.Request) {
urlpath := chi.URLParam(r, "*")
if urlpath == "" || strings.HasSuffix(urlpath, "/") {
s.listObjects(w, r)
} else {
s.serveObject(w, r)
}
})
router.MethodFunc("POST", "/*", func(w http.ResponseWriter, r *http.Request) {
urlpath := chi.URLParam(r, "*")
if urlpath == "" || strings.HasSuffix(urlpath, "/") {
s.createRepo(w, r)
} else {
s.postObject(w, r)
}
})
router.MethodFunc("HEAD", "/*", s.serveObject)
router.MethodFunc("DELETE", "/*", s.deleteObject)
}
// Bind restic server routes to passed router
func (s *server) Bind(router chi.Router) {
// FIXME
// if m := authX.Auth(authX.Opt); m != nil {
// router.Use(m)
// }
router.Use(
middleware.SetHeader("Accept-Ranges", "bytes"),
middleware.SetHeader("Server", "rclone/"+fs.Version),
WithRemote,
)
if s.opt.PrivateRepos {
router.Route("/{userID}", func(r chi.Router) {
r.Use(checkPrivate)
s.bind(r)
})
router.NotFound(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
})
} else {
s.bind(router)
} }
fs.Logf(s.f, "Serving restic REST API on %s", s.URL())
return nil
} }
var matchData = regexp.MustCompile("(?:^|/)data/([^/]{2,})$") var matchData = regexp.MustCompile("(?:^|/)data/([^/]{2,})$")
// Makes a remote from a URL path. This implements the backend layout
// required by restic.
func makeRemote(path string) string {
path = strings.Trim(path, "/")
parts := matchData.FindStringSubmatch(path)
// if no data directory, layout is flat
if parts == nil {
return path
}
// otherwise map
// data/2159dd48 to
// data/21/2159dd48
fileName := parts[1]
prefix := path[:len(path)-len(fileName)]
return prefix + fileName[:2] + "/" + fileName
}
// ServeHTTP reads incoming requests and dispatches them
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Accept-Ranges", "bytes")
w.Header().Set("Server", "rclone/"+fs.Version)
path, ok := s.Path(w, r)
if !ok {
return
}
remote := makeRemote(path)
fs.Debugf(s.f, "%s %s", r.Method, path)
v := r.Context().Value(httplib.ContextUserKey)
if privateRepos && (v == nil || !strings.HasPrefix(path, "/"+v.(string)+"/")) {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
// Dispatch on path then method
if strings.HasSuffix(path, "/") {
switch r.Method {
case "GET":
s.listObjects(w, r, remote)
case "POST":
s.createRepo(w, r, remote)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
}
} else {
switch r.Method {
case "GET", "HEAD":
s.serveObject(w, r, remote)
case "POST":
s.postObject(w, r, remote)
case "DELETE":
s.deleteObject(w, r, remote)
default:
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
}
}
}
// newObject returns an object with the remote given either from the // newObject returns an object with the remote given either from the
// cache or directly // cache or directly
func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error) { func (s *server) newObject(ctx context.Context, remote string) (fs.Object, error) {
o := s.cache.find(remote) o := s.cache.find(remote)
if o != nil { if o != nil {
return o, nil return o, nil
@ -275,7 +319,12 @@ func (s *Server) newObject(ctx context.Context, remote string) (fs.Object, error
} }
// get the remote // get the remote
func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote string) { func (s *server) serveObject(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
o, err := s.newObject(r.Context(), remote) o, err := s.newObject(r.Context(), remote)
if err != nil { if err != nil {
fs.Debugf(remote, "%s request error: %v", r.Method, err) fs.Debugf(remote, "%s request error: %v", r.Method, err)
@ -286,8 +335,13 @@ func (s *Server) serveObject(w http.ResponseWriter, r *http.Request, remote stri
} }
// postObject posts an object to the repository // postObject posts an object to the repository
func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote string) { func (s *server) postObject(w http.ResponseWriter, r *http.Request) {
if appendOnly { remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if s.opt.AppendOnly {
// make sure the file does not exist yet // make sure the file does not exist yet
_, err := s.newObject(r.Context(), remote) _, err := s.newObject(r.Context(), remote)
if err == nil { if err == nil {
@ -312,8 +366,13 @@ func (s *Server) postObject(w http.ResponseWriter, r *http.Request, remote strin
} }
// delete the remote // delete the remote
func (s *Server) deleteObject(w http.ResponseWriter, r *http.Request, remote string) { func (s *server) deleteObject(w http.ResponseWriter, r *http.Request) {
if appendOnly { remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if s.opt.AppendOnly {
parts := strings.Split(r.URL.Path, "/") parts := strings.Split(r.URL.Path, "/")
// if path doesn't end in "/locks/:name", disallow the operation // if path doesn't end in "/locks/:name", disallow the operation
@ -362,14 +421,18 @@ func (ls *listItems) add(o fs.Object) {
} }
// listObjects lists all Objects of a given type in an arbitrary order. // listObjects lists all Objects of a given type in an arbitrary order.
func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote string) { func (s *server) listObjects(w http.ResponseWriter, r *http.Request) {
fs.Debugf(remote, "list request") remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
if r.Header.Get("Accept") != resticAPIV2 { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
fs.Errorf(remote, "Restic v2 API required")
http.Error(w, "Restic v2 API required", http.StatusBadRequest)
return return
} }
if r.Header.Get("Accept") != resticAPIV2 {
fs.Errorf(remote, "Restic v2 API required for List Objects")
http.Error(w, "Restic v2 API required for List Objects", http.StatusBadRequest)
return
}
fs.Debugf(remote, "list request")
// make sure an empty list is returned, and not a 'nil' value // make sure an empty list is returned, and not a 'nil' value
ls := listItems{} ls := listItems{}
@ -408,7 +471,12 @@ func (s *Server) listObjects(w http.ResponseWriter, r *http.Request, remote stri
// createRepo creates repository directories. // createRepo creates repository directories.
// //
// We don't bother creating the data dirs as rclone will create them on the fly // We don't bother creating the data dirs as rclone will create them on the fly
func (s *Server) createRepo(w http.ResponseWriter, r *http.Request, remote string) { func (s *server) createRepo(w http.ResponseWriter, r *http.Request) {
remote, ok := r.Context().Value(ContextRemoteKey).(string)
if !ok {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
fs.Infof(remote, "Creating repository") fs.Infof(remote, "Creating repository")
if r.URL.Query().Get("create") != "true" { if r.URL.Query().Get("create") != "true" {

View File

@ -1,6 +1,7 @@
package restic package restic
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"io" "io"
@ -9,7 +10,6 @@ import (
"testing" "testing"
"github.com/rclone/rclone/cmd" "github.com/rclone/rclone/cmd"
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
"github.com/rclone/rclone/fs/config/configfile" "github.com/rclone/rclone/fs/config/configfile"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -62,6 +62,7 @@ func createOverwriteDeleteSeq(t testing.TB, path string) []TestRequest {
// TestResticHandler runs tests on the restic handler code, especially in append-only mode. // TestResticHandler runs tests on the restic handler code, especially in append-only mode.
func TestResticHandler(t *testing.T) { func TestResticHandler(t *testing.T) {
ctx := context.Background()
configfile.Install() configfile.Install()
buf := make([]byte, 32) buf := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, buf) _, err := io.ReadFull(rand.Reader, buf)
@ -110,19 +111,18 @@ func TestResticHandler(t *testing.T) {
// setup rclone with a local backend in a temporary directory // setup rclone with a local backend in a temporary directory
tempdir := t.TempDir() tempdir := t.TempDir()
// globally set append-only mode // set append-only mode
prev := appendOnly opt := newOpt()
appendOnly = true opt.AppendOnly = true
defer func() {
appendOnly = prev // reset when done
}()
// make a new file system in the temp dir // make a new file system in the temp dir
f := cmd.NewFsSrc([]string{tempdir}) f := cmd.NewFsSrc([]string{tempdir})
srv := NewServer(f, &httpflags.Opt) s, err := newServer(ctx, f, &opt)
require.NoError(t, err)
router := s.Server.Router()
// create the repo // create the repo
checkRequest(t, srv.ServeHTTP, checkRequest(t, router.ServeHTTP,
newRequest(t, "POST", "/?create=true", nil), newRequest(t, "POST", "/?create=true", nil),
[]wantFunc{wantCode(http.StatusOK)}) []wantFunc{wantCode(http.StatusOK)})
@ -130,7 +130,7 @@ func TestResticHandler(t *testing.T) {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
for i, seq := range test.seq { for i, seq := range test.seq {
t.Logf("request %v: %v %v", i, seq.req.Method, seq.req.URL.Path) t.Logf("request %v: %v %v", i, seq.req.Method, seq.req.URL.Path)
checkRequest(t, srv.ServeHTTP, seq.req, seq.want) checkRequest(t, router.ServeHTTP, seq.req, seq.want)
} }
}) })
} }

View File

@ -8,23 +8,21 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/rclone/rclone/cmd/serve/httplib"
"github.com/rclone/rclone/cmd" "github.com/rclone/rclone/cmd"
"github.com/rclone/rclone/cmd/serve/httplib/httpflags"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// newAuthenticatedRequest returns a new HTTP request with the given params. // newAuthenticatedRequest returns a new HTTP request with the given params.
func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader) *http.Request { func newAuthenticatedRequest(t testing.TB, method, path string, body io.Reader, user, pass string) *http.Request {
req := newRequest(t, method, path, body) req := newRequest(t, method, path, body)
req = req.WithContext(context.WithValue(req.Context(), httplib.ContextUserKey, "test")) req.SetBasicAuth(user, pass)
req.Header.Add("Accept", resticAPIV2) req.Header.Add("Accept", resticAPIV2)
return req return req
} }
// TestResticPrivateRepositories runs tests on the restic handler code for private repositories // TestResticPrivateRepositories runs tests on the restic handler code for private repositories
func TestResticPrivateRepositories(t *testing.T) { func TestResticPrivateRepositories(t *testing.T) {
ctx := context.Background()
buf := make([]byte, 32) buf := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, buf) _, err := io.ReadFull(rand.Reader, buf)
require.NoError(t, err) require.NoError(t, err)
@ -32,42 +30,49 @@ func TestResticPrivateRepositories(t *testing.T) {
// setup rclone with a local backend in a temporary directory // setup rclone with a local backend in a temporary directory
tempdir := t.TempDir() tempdir := t.TempDir()
// globally set private-repos mode & test user opt := newOpt()
prev := privateRepos
prevUser := httpflags.Opt.BasicUser // set private-repos mode & test user
prevPassword := httpflags.Opt.BasicPass opt.PrivateRepos = true
privateRepos = true opt.Auth.BasicUser = "test"
httpflags.Opt.BasicUser = "test" opt.Auth.BasicPass = "password"
httpflags.Opt.BasicPass = "password"
// reset when done
defer func() {
privateRepos = prev
httpflags.Opt.BasicUser = prevUser
httpflags.Opt.BasicPass = prevPassword
}()
// make a new file system in the temp dir // make a new file system in the temp dir
f := cmd.NewFsSrc([]string{tempdir}) f := cmd.NewFsSrc([]string{tempdir})
srv := NewServer(f, &httpflags.Opt) s, err := newServer(ctx, f, &opt)
require.NoError(t, err)
router := s.Server.Router()
// Requesting /test/ should allow access // Requesting /test/ should allow access
reqs := []*http.Request{ reqs := []*http.Request{
newAuthenticatedRequest(t, "POST", "/test/?create=true", nil), newAuthenticatedRequest(t, "POST", "/test/?create=true", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config")), newAuthenticatedRequest(t, "POST", "/test/config", strings.NewReader("foobar test config"), opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/test/config", nil), newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
} }
for _, req := range reqs { for _, req := range reqs {
checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)}) checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusOK)})
}
// Requesting with bad credentials should raise unauthorised errors
reqs = []*http.Request{
newRequest(t, "GET", "/test/config", nil),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, ""),
newAuthenticatedRequest(t, "GET", "/test/config", nil, "", opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser+"x", opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/test/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass+"x"),
}
for _, req := range reqs {
checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusUnauthorized)})
} }
// Requesting everything else should raise forbidden errors // Requesting everything else should raise forbidden errors
reqs = []*http.Request{ reqs = []*http.Request{
newAuthenticatedRequest(t, "GET", "/", nil), newAuthenticatedRequest(t, "GET", "/", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "POST", "/other_user", nil), newAuthenticatedRequest(t, "POST", "/other_user", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
newAuthenticatedRequest(t, "GET", "/other_user/config", nil), newAuthenticatedRequest(t, "GET", "/other_user/config", nil, opt.Auth.BasicUser, opt.Auth.BasicPass),
} }
for _, req := range reqs { for _, req := range reqs {
checkRequest(t, srv.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)}) checkRequest(t, router.ServeHTTP, req, []wantFunc{wantCode(http.StatusForbidden)})
} }
} }

View File

@ -5,14 +5,16 @@ package restic
import ( import (
"context" "context"
"net/http"
"net/http/httptest"
"os" "os"
"os/exec" "os/exec"
"testing" "testing"
_ "github.com/rclone/rclone/backend/all" _ "github.com/rclone/rclone/backend/all"
"github.com/rclone/rclone/cmd/serve/httplib"
"github.com/rclone/rclone/fstest" "github.com/rclone/rclone/fstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -20,16 +22,24 @@ const (
resticSource = "../../../../../restic/restic" resticSource = "../../../../../restic/restic"
) )
func newOpt() Options {
opt := DefaultOpt
opt.HTTP.ListenAddr = []string{testBindAddress}
return opt
}
// TestRestic runs the restic server then runs the unit tests for the // TestRestic runs the restic server then runs the unit tests for the
// restic remote against it. // restic remote against it.
func TestRestic(t *testing.T) { //
// Requires the restic source code in the location indicated by resticSource.
func TestResticIntegration(t *testing.T) {
ctx := context.Background()
_, err := os.Stat(resticSource) _, err := os.Stat(resticSource)
if err != nil { if err != nil {
t.Skipf("Skipping test as restic source not found: %v", err) t.Skipf("Skipping test as restic source not found: %v", err)
} }
opt := httplib.DefaultOpt opt := newOpt()
opt.ListenAddr = testBindAddress
fstest.Initialise() fstest.Initialise()
@ -41,16 +51,16 @@ func TestRestic(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// Start the server // Start the server
w := NewServer(fremote, &opt) s, err := newServer(ctx, fremote, &opt)
assert.NoError(t, w.Serve()) require.NoError(t, err)
testURL := s.Server.URLs()[0]
defer func() { defer func() {
w.Close() _ = s.Shutdown()
w.Wait()
}() }()
// Change directory to run the tests // Change directory to run the tests
err = os.Chdir(resticSource) err = os.Chdir(resticSource)
assert.NoError(t, err, "failed to cd to restic source code") require.NoError(t, err, "failed to cd to restic source code")
// Run the restic tests // Run the restic tests
runTests := func(path string) { runTests := func(path string) {
@ -60,7 +70,7 @@ func TestRestic(t *testing.T) {
} }
cmd := exec.Command("go", args...) cmd := exec.Command("go", args...)
cmd.Env = append(os.Environ(), cmd.Env = append(os.Environ(),
"RESTIC_TEST_REST_REPOSITORY=rest:"+w.Server.URL()+path, "RESTIC_TEST_REST_REPOSITORY=rest:"+testURL+path,
"GO111MODULE=on", "GO111MODULE=on",
) )
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
@ -81,7 +91,6 @@ func TestMakeRemote(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
in, want string in, want string
}{ }{
{"", ""},
{"/", ""}, {"/", ""},
{"/data", "data"}, {"/data", "data"},
{"/data/", "data"}, {"/data/", "data"},
@ -94,7 +103,14 @@ func TestMakeRemote(t *testing.T) {
{"/keys/12", "keys/12"}, {"/keys/12", "keys/12"},
{"/keys/123", "keys/123"}, {"/keys/123", "keys/123"},
} { } {
got := makeRemote(test.in) r := httptest.NewRequest("GET", test.in, nil)
assert.Equal(t, test.want, got, test.in) w := httptest.NewRecorder()
next := http.HandlerFunc(func(_ http.ResponseWriter, request *http.Request) {
remote, ok := request.Context().Value(ContextRemoteKey).(string)
assert.True(t, ok, "Failed to get remote from context")
assert.Equal(t, test.want, remote, test.in)
})
got := WithRemote(next)
got.ServeHTTP(w, r)
} }
} }

View File

@ -7,7 +7,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// declare a few helper functions // declare a few helper functions
@ -15,11 +14,10 @@ import (
// wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect. // wantFunc tests the HTTP response in res and marks the test as errored if something is incorrect.
type wantFunc func(t testing.TB, res *httptest.ResponseRecorder) type wantFunc func(t testing.TB, res *httptest.ResponseRecorder)
// newRequest returns a new HTTP request with the given params. On error, the // newRequest returns a new HTTP request with the given params
// test is marked as failed.
func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request { func newRequest(t testing.TB, method, path string, body io.Reader) *http.Request {
req, err := http.NewRequest(method, path, body) req := httptest.NewRequest(method, path, body)
require.NoError(t, err) req.Header.Add("Accept", resticAPIV2)
return req return req
} }