mirror of
https://github.com/rclone/rclone.git
synced 2024-11-07 09:04:52 +01:00
lib/http: Simplify server.go to export an http server rather than an interface
This also makes the implementation public.
This commit is contained in:
parent
2a2fcf1012
commit
ec7cc2b3c3
@ -103,7 +103,7 @@ control the stats printing.
|
||||
type serveCmd struct {
|
||||
f fs.Fs
|
||||
vfs *vfs.VFS
|
||||
server libhttp.Server
|
||||
server *libhttp.Server
|
||||
}
|
||||
|
||||
func run(ctx context.Context, f fs.Fs, opt Options) (*serveCmd, error) {
|
||||
|
@ -3,7 +3,7 @@ package serve
|
||||
import (
|
||||
"errors"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -94,7 +94,7 @@ func TestError(t *testing.T) {
|
||||
Error("potato", w, "sausage", err)
|
||||
resp := w.Result()
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "sausage.\n", string(body))
|
||||
}
|
||||
|
||||
@ -108,7 +108,7 @@ func TestServe(t *testing.T) {
|
||||
d.Serve(w, r)
|
||||
resp := w.Result()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
|
@ -1,7 +1,7 @@
|
||||
package serve
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@ -17,7 +17,7 @@ func TestObjectBadMethod(t *testing.T) {
|
||||
Object(w, r, o)
|
||||
resp := w.Result()
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "Method Not Allowed\n", string(body))
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ func TestObjectHEAD(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "5", resp.Header.Get("Content-Length"))
|
||||
assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges"))
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "", string(body))
|
||||
}
|
||||
|
||||
@ -43,7 +43,7 @@ func TestObjectGET(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "5", resp.Header.Get("Content-Length"))
|
||||
assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges"))
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "hello", string(body))
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ func TestObjectRange(t *testing.T) {
|
||||
assert.Equal(t, "3", resp.Header.Get("Content-Length"))
|
||||
assert.Equal(t, "bytes", resp.Header.Get("Accept-Ranges"))
|
||||
assert.Equal(t, "bytes 3-5/10", resp.Header.Get("Content-Range"))
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "345", string(body))
|
||||
}
|
||||
|
||||
@ -71,6 +71,6 @@ func TestObjectBadRange(t *testing.T) {
|
||||
resp := w.Result()
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
assert.Equal(t, "10", resp.Header.Get("Content-Length"))
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "Bad Request\n", string(body))
|
||||
}
|
||||
|
@ -121,16 +121,6 @@ func DefaultCfg() Config {
|
||||
}
|
||||
}
|
||||
|
||||
// Server interface of http server
|
||||
type Server interface {
|
||||
Router() chi.Router
|
||||
Serve()
|
||||
Shutdown() error
|
||||
HTMLTemplate() *template.Template
|
||||
URLs() []string
|
||||
Wait()
|
||||
}
|
||||
|
||||
type instance struct {
|
||||
url string
|
||||
listener net.Listener
|
||||
@ -145,7 +135,8 @@ func (s instance) serve(wg *sync.WaitGroup) {
|
||||
}
|
||||
}
|
||||
|
||||
type server struct {
|
||||
// Server contains info about the running http server
|
||||
type Server struct {
|
||||
wg sync.WaitGroup
|
||||
mux chi.Router
|
||||
tlsConfig *tls.Config
|
||||
@ -157,25 +148,25 @@ type server struct {
|
||||
}
|
||||
|
||||
// Option allows customizing the server
|
||||
type Option func(*server)
|
||||
type Option func(*Server)
|
||||
|
||||
// WithAuth option initializes the appropriate auth middleware
|
||||
func WithAuth(cfg AuthConfig) Option {
|
||||
return func(s *server) {
|
||||
return func(s *Server) {
|
||||
s.auth = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfig option applies the Config to the server, overriding defaults
|
||||
func WithConfig(cfg Config) Option {
|
||||
return func(s *server) {
|
||||
return func(s *Server) {
|
||||
s.cfg = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// WithTemplate option allows the parsing of a template
|
||||
func WithTemplate(cfg TemplateConfig) Option {
|
||||
return func(s *server) {
|
||||
return func(s *Server) {
|
||||
s.template = &cfg
|
||||
}
|
||||
}
|
||||
@ -184,12 +175,17 @@ func WithTemplate(cfg TemplateConfig) Option {
|
||||
// This function is provided if the default http server does not meet a services requirements and should not generally be used
|
||||
// A http server can listen using multiple listeners. For example, a listener for port 80, and a listener for port 443.
|
||||
// tlsListeners are ignored if opt.TLSKey is not provided
|
||||
func NewServer(ctx context.Context, options ...Option) (*server, error) {
|
||||
s := &server{
|
||||
func NewServer(ctx context.Context, options ...Option) (*Server, error) {
|
||||
s := &Server{
|
||||
mux: chi.NewRouter(),
|
||||
cfg: DefaultCfg(),
|
||||
}
|
||||
|
||||
// Make sure default logger is logging where everything else is
|
||||
// middleware.DefaultLogger = middleware.RequestLogger(&middleware.DefaultLogFormatter{Logger: log.Default(), NoColor: true})
|
||||
// Log requests
|
||||
// s.mux.Use(middleware.Logger)
|
||||
|
||||
for _, opt := range options {
|
||||
opt(s)
|
||||
}
|
||||
@ -275,7 +271,7 @@ func NewServer(ctx context.Context, options ...Option) (*server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) initAuth() {
|
||||
func (s *Server) initAuth() {
|
||||
if s.auth.CustomAuthFn != nil {
|
||||
s.mux.Use(MiddlewareAuthCustom(s.auth.CustomAuthFn, s.auth.Realm))
|
||||
return
|
||||
@ -292,7 +288,7 @@ func (s *server) initAuth() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) initTemplate() error {
|
||||
func (s *Server) initTemplate() error {
|
||||
if s.template == nil {
|
||||
return nil
|
||||
}
|
||||
@ -317,7 +313,7 @@ var (
|
||||
ErrTLSParseCA = errors.New("unable to parse client certificate authority")
|
||||
)
|
||||
|
||||
func (s *server) initTLS() error {
|
||||
func (s *Server) initTLS() error {
|
||||
if s.cfg.TLSCert == "" && s.cfg.TLSKey == "" && len(s.cfg.TLSCertBody) == 0 && len(s.cfg.TLSKeyBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -383,7 +379,7 @@ func (s *server) initTLS() error {
|
||||
}
|
||||
|
||||
// Serve starts the HTTP server on each listener
|
||||
func (s *server) Serve() {
|
||||
func (s *Server) Serve() {
|
||||
s.wg.Add(len(s.instances))
|
||||
for _, ii := range s.instances {
|
||||
// TODO: decide how/when to log listening url
|
||||
@ -393,17 +389,17 @@ func (s *server) Serve() {
|
||||
}
|
||||
|
||||
// Wait blocks while the server is serving requests
|
||||
func (s *server) Wait() {
|
||||
func (s *Server) Wait() {
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
// Router returns the server base router
|
||||
func (s *server) Router() chi.Router {
|
||||
func (s *Server) Router() chi.Router {
|
||||
return s.mux
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server
|
||||
func (s *server) Shutdown() error {
|
||||
func (s *Server) Shutdown() error {
|
||||
ctx := context.Background()
|
||||
for _, ii := range s.instances {
|
||||
if err := ii.httpServer.Shutdown(ctx); err != nil {
|
||||
@ -416,12 +412,12 @@ func (s *server) Shutdown() error {
|
||||
}
|
||||
|
||||
// HTMLTemplate returns the parsed template, if WithTemplate option was passed.
|
||||
func (s *server) HTMLTemplate() *template.Template {
|
||||
func (s *Server) HTMLTemplate() *template.Template {
|
||||
return s.htmlTemplate
|
||||
}
|
||||
|
||||
// URLs returns all configured URLS
|
||||
func (s *server) URLs() []string {
|
||||
func (s *Server) URLs() []string {
|
||||
var out []string
|
||||
for _, ii := range s.instances {
|
||||
if ii.listener.Addr().Network() == "unix" {
|
||||
|
@ -26,7 +26,7 @@ func testExpectRespBody(t *testing.T, resp *http.Response, expected []byte) {
|
||||
require.Equal(t, expected, body)
|
||||
}
|
||||
|
||||
func testGetServerURL(t *testing.T, s Server) string {
|
||||
func testGetServerURL(t *testing.T, s *Server) string {
|
||||
urls := s.URLs()
|
||||
require.GreaterOrEqual(t, len(urls), 1, "server should return at least one url")
|
||||
return urls[0]
|
||||
|
Loading…
Reference in New Issue
Block a user