diff --git a/backend/server/internal/server/middleware.go b/backend/server/internal/server/middleware.go index 0736527..11bd8e1 100644 --- a/backend/server/internal/server/middleware.go +++ b/backend/server/internal/server/middleware.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "io" "net/http" "reflect" "runtime" @@ -50,10 +51,23 @@ func byteCountToString(b int) string { return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp]) } -type Middleware func(http.HandlerFunc) http.Handler +type Middleware func(http.Handler) http.Handler -func withLogging(s *statsd.Client) Middleware { - return func(h http.HandlerFunc) http.Handler { +// mergeMiddlewares creates a new middleware that runs the given middlewares in reverse order. The first middleware +// passed will be the "outermost" one +func mergeMiddlewares(middlewares ...Middleware) Middleware { + return func(h http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + h = middlewares[i](h) + } + return h + } +} + +// withLogging will log every request made to the wrapped endpoint. It will also log +// panics, but won't stop them. +func withLogging(s *statsd.Client, out io.Writer) Middleware { + return func(h http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { var responseData loggedResponseData lrw := loggingResponseWriter{ @@ -69,10 +83,21 @@ func withLogging(s *statsd.Client) Middleware { ) defer span.Finish() + defer func() { + // log panics + if err := recover(); err != nil { + duration := time.Since(start) + fmt.Fprintf(out, "%s %s %#v %s %s %s %v\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size), err) + + // keep panicking + panic(err) + } + }() + h.ServeHTTP(&lrw, r.WithContext(ctx)) duration := time.Since(start) - fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) + fmt.Fprintf(out, "%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) if s != nil { s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"handler:" + getFunctionName(h)}, 1.0) s.Incr("hishtory.request", []string{"handler:" + getFunctionName(h)}, 1.0) @@ -80,3 +105,19 @@ func withLogging(s *statsd.Client) Middleware { }) } } + +// withPanicGuard is the last defence from a panic. it will log them and return a 500 error +// to the client and prevent the http server from breaking +func withPanicGuard() Middleware { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("panic: %s\n", r) + rw.WriteHeader(http.StatusInternalServerError) + } + }() + h.ServeHTTP(rw, r) + }) + } +} diff --git a/backend/server/internal/server/middleware_test.go b/backend/server/internal/server/middleware_test.go new file mode 100644 index 0000000..1a95fcb --- /dev/null +++ b/backend/server/internal/server/middleware_test.go @@ -0,0 +1,210 @@ +package server + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestLoggerMiddleware(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test")) + }) + var out strings.Builder + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("X-Real-Ip", "127.0.0.1") + logHandler := withLogging(nil, &out)(handler) + logHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected %d, got %d", http.StatusOK, w.Code) + } + const expectedPiece = `127.0.0.1 GET "/"` + if !strings.Contains(out.String(), expectedPiece) { + t.Errorf("expected %q, got %q", expectedPiece, out.String()) + } +} + +func TestLoggerMiddlewareWithPanic(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(fmt.Errorf("oh no")) + }) + + var out strings.Builder + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("X-Real-Ip", "127.0.0.1") + logHandler := withLogging(nil, &out)(handler) + + var panicked bool + var panicError any + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + panicError = r + } + }() + logHandler.ServeHTTP(w, req) + }() + + if !panicked { + t.Errorf("expected panic") + } + // the logger does not write anything if there is a panic, so the response code is the http default of 200 + if w.Code != http.StatusOK { + t.Errorf("expected %d, got %d", http.StatusOK, w.Code) + } + + const expectedPiece1 = `oh no` + const expectedPiece2 = `127.0.0.1 GET "/"` + outString := out.String() + if !strings.Contains(outString, expectedPiece1) { + t.Errorf("expected %q, got %q", expectedPiece1, outString) + } + if !strings.Contains(outString, expectedPiece2) { + t.Errorf("expected %q, got %q", expectedPiece2, outString) + } + + panicStr := fmt.Sprintf("%v", panicError) + if !strings.Contains(panicStr, "oh no") { + t.Errorf("expected panic error to contain %q, got %q", "oh no", panicStr) + } +} + +func TestPanicGuard(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(fmt.Errorf("oh no")) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("X-Real-Ip", "127.0.0.1") + wrappedHandler := withPanicGuard()(handler) + + var panicked bool + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + wrappedHandler.ServeHTTP(w, req) + }() + + if panicked { + t.Fatalf("expected no panic") + } + if w.Code != http.StatusInternalServerError { + t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code) + } +} + +func TestPanicGuardNoPanic(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("X-Real-Ip", "127.0.0.1") + + wrappedHandler := withPanicGuard()(handler) + + var panicked bool + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + wrappedHandler.ServeHTTP(w, req) + }() + + if panicked { + t.Fatalf("expected no panic") + } + if w.Code != http.StatusOK { + t.Errorf("expected status %d, got %d", http.StatusOK, w.Code) + } +} + +func TestMergeMiddlewares(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("test")) + }) + panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(fmt.Errorf("oh no")) + }) + + // === + tests := []struct { + name string + handler http.Handler + expectedStatusCode int + expectedPieces []string + }{ + { + name: "no panics", + handler: handler, + expectedStatusCode: http.StatusOK, + expectedPieces: []string{ + `127.0.0.1 GET "/"`, + }, + }, + { + name: "panics", + handler: panicHandler, + expectedStatusCode: http.StatusInternalServerError, + expectedPieces: []string{ + `oh no`, + `127.0.0.1 GET "/"`, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var out strings.Builder + middlewares := mergeMiddlewares( + withPanicGuard(), + withLogging(nil, &out), + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add("X-Real-Ip", "127.0.0.1") + + wrappedHandler := middlewares(test.handler) + var panicked bool + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + wrappedHandler.ServeHTTP(w, req) + }() + + if panicked { + t.Fatalf("expected no panic") + } + if w.Code != test.expectedStatusCode { + t.Errorf("expected response status to be %d, got %d", test.expectedStatusCode, w.Code) + } + + for _, expectedPiece := range test.expectedPieces { + if !strings.Contains(out.String(), expectedPiece) { + t.Errorf("expected %q, got %q", expectedPiece, out.String()) + } + } + }) + } +} diff --git a/backend/server/internal/server/srv.go b/backend/server/internal/server/srv.go index d45b4ee..9dbdb05 100644 --- a/backend/server/internal/server/srv.go +++ b/backend/server/internal/server/srv.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "os" "strings" "time" @@ -93,28 +94,31 @@ func (s *Server) Run(ctx context.Context, addr string) error { } }() } - loggerMiddleware := withLogging(s.statsd) + middlewares := mergeMiddlewares( + withPanicGuard(), + withLogging(s.statsd, os.Stdout), + ) - mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) - mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) - mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) - mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) - mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) - mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) - mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) - mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) - mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) - mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) - mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) - mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) - mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) - mux.Handle("/api/v1/ping", loggerMiddleware(s.pingHandler)) - mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) - mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) - mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) + mux.Handle("/api/v1/submit", middlewares(http.HandlerFunc(s.apiSubmitHandler))) + mux.Handle("/api/v1/get-dump-requests", middlewares(http.HandlerFunc(s.apiGetPendingDumpRequestsHandler))) + mux.Handle("/api/v1/submit-dump", middlewares(http.HandlerFunc(s.apiSubmitDumpHandler))) + mux.Handle("/api/v1/query", middlewares(http.HandlerFunc(s.apiQueryHandler))) + mux.Handle("/api/v1/bootstrap", middlewares(http.HandlerFunc(s.apiBootstrapHandler))) + mux.Handle("/api/v1/register", middlewares(http.HandlerFunc(s.apiRegisterHandler))) + mux.Handle("/api/v1/banner", middlewares(http.HandlerFunc(s.apiBannerHandler))) + mux.Handle("/api/v1/download", middlewares(http.HandlerFunc(s.apiDownloadHandler))) + mux.Handle("/api/v1/trigger-cron", middlewares(http.HandlerFunc(s.triggerCronHandler))) + mux.Handle("/api/v1/get-deletion-requests", middlewares(http.HandlerFunc(s.getDeletionRequestsHandler))) + mux.Handle("/api/v1/add-deletion-request", middlewares(http.HandlerFunc(s.addDeletionRequestHandler))) + mux.Handle("/api/v1/slsa-status", middlewares(http.HandlerFunc(s.slsaStatusHandler))) + mux.Handle("/api/v1/feedback", middlewares(http.HandlerFunc(s.feedbackHandler))) + mux.Handle("/api/v1/ping", middlewares(http.HandlerFunc(s.pingHandler))) + mux.Handle("/healthcheck", middlewares(http.HandlerFunc(s.healthCheckHandler))) + mux.Handle("/internal/api/v1/usage-stats", middlewares(http.HandlerFunc(s.usageStatsHandler))) + mux.Handle("/internal/api/v1/stats", middlewares(http.HandlerFunc(s.statsHandler))) if s.isTestEnvironment { - mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) - mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) + mux.Handle("/api/v1/wipe-db-entries", middlewares(http.HandlerFunc(s.wipeDbEntriesHandler))) + mux.Handle("/api/v1/get-num-connections", middlewares(http.HandlerFunc(s.getNumConnectionsHandler))) } httpServer := &http.Server{