Merge pull request #114 from lsmoura/sergio/panic-guard

Add improve handling for panics in the server
This commit is contained in:
David Dworken 2023-09-30 07:24:01 -07:00 committed by GitHub
commit 6539f834c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 279 additions and 24 deletions

View File

@ -2,6 +2,7 @@ package server
import (
"fmt"
"io"
"net/http"
"reflect"
"runtime"
@ -50,10 +51,23 @@ func byteCountToString(b int) string {
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
}
type Middleware func(http.HandlerFunc) http.Handler
type Middleware func(http.Handler) http.Handler
func withLogging(s *statsd.Client) Middleware {
return func(h http.HandlerFunc) http.Handler {
// mergeMiddlewares creates a new middleware that runs the given middlewares in reverse order. The first middleware
// passed will be the "outermost" one
func mergeMiddlewares(middlewares ...Middleware) Middleware {
return func(h http.Handler) http.Handler {
for i := len(middlewares) - 1; i >= 0; i-- {
h = middlewares[i](h)
}
return h
}
}
// withLogging will log every request made to the wrapped endpoint. It will also log
// panics, but won't stop them.
func withLogging(s *statsd.Client, out io.Writer) Middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var responseData loggedResponseData
lrw := loggingResponseWriter{
@ -69,10 +83,21 @@ func withLogging(s *statsd.Client) Middleware {
)
defer span.Finish()
defer func() {
// log panics
if err := recover(); err != nil {
duration := time.Since(start)
fmt.Fprintf(out, "%s %s %#v %s %s %s %v\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size), err)
// keep panicking
panic(err)
}
}()
h.ServeHTTP(&lrw, r.WithContext(ctx))
duration := time.Since(start)
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
fmt.Fprintf(out, "%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
if s != nil {
s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"handler:" + getFunctionName(h)}, 1.0)
s.Incr("hishtory.request", []string{"handler:" + getFunctionName(h)}, 1.0)
@ -80,3 +105,19 @@ func withLogging(s *statsd.Client) Middleware {
})
}
}
// withPanicGuard is the last defence from a panic. it will log them and return a 500 error
// to the client and prevent the http server from breaking
func withPanicGuard() Middleware {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
fmt.Printf("panic: %s\n", r)
rw.WriteHeader(http.StatusInternalServerError)
}
}()
h.ServeHTTP(rw, r)
})
}
}

View File

@ -0,0 +1,210 @@
package server
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestLoggerMiddleware(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test"))
})
var out strings.Builder
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("X-Real-Ip", "127.0.0.1")
logHandler := withLogging(nil, &out)(handler)
logHandler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("expected %d, got %d", http.StatusOK, w.Code)
}
const expectedPiece = `127.0.0.1 GET "/"`
if !strings.Contains(out.String(), expectedPiece) {
t.Errorf("expected %q, got %q", expectedPiece, out.String())
}
}
func TestLoggerMiddlewareWithPanic(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(fmt.Errorf("oh no"))
})
var out strings.Builder
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("X-Real-Ip", "127.0.0.1")
logHandler := withLogging(nil, &out)(handler)
var panicked bool
var panicError any
func() {
defer func() {
if r := recover(); r != nil {
panicked = true
panicError = r
}
}()
logHandler.ServeHTTP(w, req)
}()
if !panicked {
t.Errorf("expected panic")
}
// the logger does not write anything if there is a panic, so the response code is the http default of 200
if w.Code != http.StatusOK {
t.Errorf("expected %d, got %d", http.StatusOK, w.Code)
}
const expectedPiece1 = `oh no`
const expectedPiece2 = `127.0.0.1 GET "/"`
outString := out.String()
if !strings.Contains(outString, expectedPiece1) {
t.Errorf("expected %q, got %q", expectedPiece1, outString)
}
if !strings.Contains(outString, expectedPiece2) {
t.Errorf("expected %q, got %q", expectedPiece2, outString)
}
panicStr := fmt.Sprintf("%v", panicError)
if !strings.Contains(panicStr, "oh no") {
t.Errorf("expected panic error to contain %q, got %q", "oh no", panicStr)
}
}
func TestPanicGuard(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(fmt.Errorf("oh no"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("X-Real-Ip", "127.0.0.1")
wrappedHandler := withPanicGuard()(handler)
var panicked bool
func() {
defer func() {
if r := recover(); r != nil {
panicked = true
}
}()
wrappedHandler.ServeHTTP(w, req)
}()
if panicked {
t.Fatalf("expected no panic")
}
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
}
func TestPanicGuardNoPanic(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("X-Real-Ip", "127.0.0.1")
wrappedHandler := withPanicGuard()(handler)
var panicked bool
func() {
defer func() {
if r := recover(); r != nil {
panicked = true
}
}()
wrappedHandler.ServeHTTP(w, req)
}()
if panicked {
t.Fatalf("expected no panic")
}
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
func TestMergeMiddlewares(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test"))
})
panicHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic(fmt.Errorf("oh no"))
})
// ===
tests := []struct {
name string
handler http.Handler
expectedStatusCode int
expectedPieces []string
}{
{
name: "no panics",
handler: handler,
expectedStatusCode: http.StatusOK,
expectedPieces: []string{
`127.0.0.1 GET "/"`,
},
},
{
name: "panics",
handler: panicHandler,
expectedStatusCode: http.StatusInternalServerError,
expectedPieces: []string{
`oh no`,
`127.0.0.1 GET "/"`,
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var out strings.Builder
middlewares := mergeMiddlewares(
withPanicGuard(),
withLogging(nil, &out),
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("X-Real-Ip", "127.0.0.1")
wrappedHandler := middlewares(test.handler)
var panicked bool
func() {
defer func() {
if r := recover(); r != nil {
panicked = true
}
}()
wrappedHandler.ServeHTTP(w, req)
}()
if panicked {
t.Fatalf("expected no panic")
}
if w.Code != test.expectedStatusCode {
t.Errorf("expected response status to be %d, got %d", test.expectedStatusCode, w.Code)
}
for _, expectedPiece := range test.expectedPieces {
if !strings.Contains(out.String(), expectedPiece) {
t.Errorf("expected %q, got %q", expectedPiece, out.String())
}
}
})
}
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
@ -93,28 +94,31 @@ func (s *Server) Run(ctx context.Context, addr string) error {
}
}()
}
loggerMiddleware := withLogging(s.statsd)
middlewares := mergeMiddlewares(
withPanicGuard(),
withLogging(s.statsd, os.Stdout),
)
mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler))
mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler))
mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler))
mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler))
mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler))
mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler))
mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler))
mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler))
mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler))
mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler))
mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler))
mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler))
mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler))
mux.Handle("/api/v1/ping", loggerMiddleware(s.pingHandler))
mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler))
mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler))
mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler))
mux.Handle("/api/v1/submit", middlewares(http.HandlerFunc(s.apiSubmitHandler)))
mux.Handle("/api/v1/get-dump-requests", middlewares(http.HandlerFunc(s.apiGetPendingDumpRequestsHandler)))
mux.Handle("/api/v1/submit-dump", middlewares(http.HandlerFunc(s.apiSubmitDumpHandler)))
mux.Handle("/api/v1/query", middlewares(http.HandlerFunc(s.apiQueryHandler)))
mux.Handle("/api/v1/bootstrap", middlewares(http.HandlerFunc(s.apiBootstrapHandler)))
mux.Handle("/api/v1/register", middlewares(http.HandlerFunc(s.apiRegisterHandler)))
mux.Handle("/api/v1/banner", middlewares(http.HandlerFunc(s.apiBannerHandler)))
mux.Handle("/api/v1/download", middlewares(http.HandlerFunc(s.apiDownloadHandler)))
mux.Handle("/api/v1/trigger-cron", middlewares(http.HandlerFunc(s.triggerCronHandler)))
mux.Handle("/api/v1/get-deletion-requests", middlewares(http.HandlerFunc(s.getDeletionRequestsHandler)))
mux.Handle("/api/v1/add-deletion-request", middlewares(http.HandlerFunc(s.addDeletionRequestHandler)))
mux.Handle("/api/v1/slsa-status", middlewares(http.HandlerFunc(s.slsaStatusHandler)))
mux.Handle("/api/v1/feedback", middlewares(http.HandlerFunc(s.feedbackHandler)))
mux.Handle("/api/v1/ping", middlewares(http.HandlerFunc(s.pingHandler)))
mux.Handle("/healthcheck", middlewares(http.HandlerFunc(s.healthCheckHandler)))
mux.Handle("/internal/api/v1/usage-stats", middlewares(http.HandlerFunc(s.usageStatsHandler)))
mux.Handle("/internal/api/v1/stats", middlewares(http.HandlerFunc(s.statsHandler)))
if s.isTestEnvironment {
mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler))
mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler))
mux.Handle("/api/v1/wipe-db-entries", middlewares(http.HandlerFunc(s.wipeDbEntriesHandler)))
mux.Handle("/api/v1/get-num-connections", middlewares(http.HandlerFunc(s.getNumConnectionsHandler)))
}
httpServer := &http.Server{