mirror of
https://github.com/ddworken/hishtory.git
synced 2024-11-27 02:34:06 +01:00
Merge pull request #114 from lsmoura/sergio/panic-guard
Add improve handling for panics in the server
This commit is contained in:
commit
6539f834c7
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"runtime"
|
||||
@ -50,10 +51,23 @@ func byteCountToString(b int) string {
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
|
||||
}
|
||||
|
||||
type Middleware func(http.HandlerFunc) http.Handler
|
||||
type Middleware func(http.Handler) http.Handler
|
||||
|
||||
func withLogging(s *statsd.Client) Middleware {
|
||||
return func(h http.HandlerFunc) http.Handler {
|
||||
// mergeMiddlewares creates a new middleware that runs the given middlewares in reverse order. The first middleware
|
||||
// passed will be the "outermost" one
|
||||
func mergeMiddlewares(middlewares ...Middleware) Middleware {
|
||||
return func(h http.Handler) http.Handler {
|
||||
for i := len(middlewares) - 1; i >= 0; i-- {
|
||||
h = middlewares[i](h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
}
|
||||
|
||||
// withLogging will log every request made to the wrapped endpoint. It will also log
|
||||
// panics, but won't stop them.
|
||||
func withLogging(s *statsd.Client, out io.Writer) Middleware {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
var responseData loggedResponseData
|
||||
lrw := loggingResponseWriter{
|
||||
@ -69,10 +83,21 @@ func withLogging(s *statsd.Client) Middleware {
|
||||
)
|
||||
defer span.Finish()
|
||||
|
||||
defer func() {
|
||||
// log panics
|
||||
if err := recover(); err != nil {
|
||||
duration := time.Since(start)
|
||||
fmt.Fprintf(out, "%s %s %#v %s %s %s %v\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size), err)
|
||||
|
||||
// keep panicking
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
h.ServeHTTP(&lrw, r.WithContext(ctx))
|
||||
|
||||
duration := time.Since(start)
|
||||
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
|
||||
fmt.Fprintf(out, "%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
|
||||
if s != nil {
|
||||
s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"handler:" + getFunctionName(h)}, 1.0)
|
||||
s.Incr("hishtory.request", []string{"handler:" + getFunctionName(h)}, 1.0)
|
||||
@ -80,3 +105,19 @@ func withLogging(s *statsd.Client) Middleware {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// withPanicGuard is the last defence from a panic. it will log them and return a 500 error
|
||||
// to the client and prevent the http server from breaking
|
||||
func withPanicGuard() Middleware {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Printf("panic: %s\n", r)
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
h.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
210
backend/server/internal/server/middleware_test.go
Normal file
210
backend/server/internal/server/middleware_test.go
Normal file
@ -0,0 +1,210 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoggerMiddleware(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
var out strings.Builder
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("X-Real-Ip", "127.0.0.1")
|
||||
logHandler := withLogging(nil, &out)(handler)
|
||||
logHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
const expectedPiece = `127.0.0.1 GET "/"`
|
||||
if !strings.Contains(out.String(), expectedPiece) {
|
||||
t.Errorf("expected %q, got %q", expectedPiece, out.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerMiddlewareWithPanic(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic(fmt.Errorf("oh no"))
|
||||
})
|
||||
|
||||
var out strings.Builder
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("X-Real-Ip", "127.0.0.1")
|
||||
logHandler := withLogging(nil, &out)(handler)
|
||||
|
||||
var panicked bool
|
||||
var panicError any
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
panicError = r
|
||||
}
|
||||
}()
|
||||
logHandler.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
if !panicked {
|
||||
t.Errorf("expected panic")
|
||||
}
|
||||
// the logger does not write anything if there is a panic, so the response code is the http default of 200
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
const expectedPiece1 = `oh no`
|
||||
const expectedPiece2 = `127.0.0.1 GET "/"`
|
||||
outString := out.String()
|
||||
if !strings.Contains(outString, expectedPiece1) {
|
||||
t.Errorf("expected %q, got %q", expectedPiece1, outString)
|
||||
}
|
||||
if !strings.Contains(outString, expectedPiece2) {
|
||||
t.Errorf("expected %q, got %q", expectedPiece2, outString)
|
||||
}
|
||||
|
||||
panicStr := fmt.Sprintf("%v", panicError)
|
||||
if !strings.Contains(panicStr, "oh no") {
|
||||
t.Errorf("expected panic error to contain %q, got %q", "oh no", panicStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPanicGuard(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic(fmt.Errorf("oh no"))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("X-Real-Ip", "127.0.0.1")
|
||||
wrappedHandler := withPanicGuard()(handler)
|
||||
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
if panicked {
|
||||
t.Fatalf("expected no panic")
|
||||
}
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPanicGuardNoPanic(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("X-Real-Ip", "127.0.0.1")
|
||||
|
||||
wrappedHandler := withPanicGuard()(handler)
|
||||
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
if panicked {
|
||||
t.Fatalf("expected no panic")
|
||||
}
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeMiddlewares(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
panic(fmt.Errorf("oh no"))
|
||||
})
|
||||
|
||||
// ===
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
expectedStatusCode int
|
||||
expectedPieces []string
|
||||
}{
|
||||
{
|
||||
name: "no panics",
|
||||
handler: handler,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPieces: []string{
|
||||
`127.0.0.1 GET "/"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "panics",
|
||||
handler: panicHandler,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
expectedPieces: []string{
|
||||
`oh no`,
|
||||
`127.0.0.1 GET "/"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
middlewares := mergeMiddlewares(
|
||||
withPanicGuard(),
|
||||
withLogging(nil, &out),
|
||||
)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add("X-Real-Ip", "127.0.0.1")
|
||||
|
||||
wrappedHandler := middlewares(test.handler)
|
||||
var panicked bool
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
wrappedHandler.ServeHTTP(w, req)
|
||||
}()
|
||||
|
||||
if panicked {
|
||||
t.Fatalf("expected no panic")
|
||||
}
|
||||
if w.Code != test.expectedStatusCode {
|
||||
t.Errorf("expected response status to be %d, got %d", test.expectedStatusCode, w.Code)
|
||||
}
|
||||
|
||||
for _, expectedPiece := range test.expectedPieces {
|
||||
if !strings.Contains(out.String(), expectedPiece) {
|
||||
t.Errorf("expected %q, got %q", expectedPiece, out.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -93,28 +94,31 @@ func (s *Server) Run(ctx context.Context, addr string) error {
|
||||
}
|
||||
}()
|
||||
}
|
||||
loggerMiddleware := withLogging(s.statsd)
|
||||
middlewares := mergeMiddlewares(
|
||||
withPanicGuard(),
|
||||
withLogging(s.statsd, os.Stdout),
|
||||
)
|
||||
|
||||
mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler))
|
||||
mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler))
|
||||
mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler))
|
||||
mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler))
|
||||
mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler))
|
||||
mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler))
|
||||
mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler))
|
||||
mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler))
|
||||
mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler))
|
||||
mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler))
|
||||
mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler))
|
||||
mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler))
|
||||
mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler))
|
||||
mux.Handle("/api/v1/ping", loggerMiddleware(s.pingHandler))
|
||||
mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler))
|
||||
mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler))
|
||||
mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler))
|
||||
mux.Handle("/api/v1/submit", middlewares(http.HandlerFunc(s.apiSubmitHandler)))
|
||||
mux.Handle("/api/v1/get-dump-requests", middlewares(http.HandlerFunc(s.apiGetPendingDumpRequestsHandler)))
|
||||
mux.Handle("/api/v1/submit-dump", middlewares(http.HandlerFunc(s.apiSubmitDumpHandler)))
|
||||
mux.Handle("/api/v1/query", middlewares(http.HandlerFunc(s.apiQueryHandler)))
|
||||
mux.Handle("/api/v1/bootstrap", middlewares(http.HandlerFunc(s.apiBootstrapHandler)))
|
||||
mux.Handle("/api/v1/register", middlewares(http.HandlerFunc(s.apiRegisterHandler)))
|
||||
mux.Handle("/api/v1/banner", middlewares(http.HandlerFunc(s.apiBannerHandler)))
|
||||
mux.Handle("/api/v1/download", middlewares(http.HandlerFunc(s.apiDownloadHandler)))
|
||||
mux.Handle("/api/v1/trigger-cron", middlewares(http.HandlerFunc(s.triggerCronHandler)))
|
||||
mux.Handle("/api/v1/get-deletion-requests", middlewares(http.HandlerFunc(s.getDeletionRequestsHandler)))
|
||||
mux.Handle("/api/v1/add-deletion-request", middlewares(http.HandlerFunc(s.addDeletionRequestHandler)))
|
||||
mux.Handle("/api/v1/slsa-status", middlewares(http.HandlerFunc(s.slsaStatusHandler)))
|
||||
mux.Handle("/api/v1/feedback", middlewares(http.HandlerFunc(s.feedbackHandler)))
|
||||
mux.Handle("/api/v1/ping", middlewares(http.HandlerFunc(s.pingHandler)))
|
||||
mux.Handle("/healthcheck", middlewares(http.HandlerFunc(s.healthCheckHandler)))
|
||||
mux.Handle("/internal/api/v1/usage-stats", middlewares(http.HandlerFunc(s.usageStatsHandler)))
|
||||
mux.Handle("/internal/api/v1/stats", middlewares(http.HandlerFunc(s.statsHandler)))
|
||||
if s.isTestEnvironment {
|
||||
mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler))
|
||||
mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler))
|
||||
mux.Handle("/api/v1/wipe-db-entries", middlewares(http.HandlerFunc(s.wipeDbEntriesHandler)))
|
||||
mux.Handle("/api/v1/get-num-connections", middlewares(http.HandlerFunc(s.getNumConnectionsHandler)))
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
|
Loading…
Reference in New Issue
Block a user