diff --git a/backend/server/middleware.go b/backend/server/middleware.go new file mode 100644 index 0000000..5d8d7e7 --- /dev/null +++ b/backend/server/middleware.go @@ -0,0 +1,81 @@ +package main + +import ( + "fmt" + "github.com/DataDog/datadog-go/statsd" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "net/http" + "reflect" + "runtime" + "strings" + "time" +) + +type loggedResponseData struct { + size int +} + +type loggingResponseWriter struct { + http.ResponseWriter + responseData *loggedResponseData +} + +func (r *loggingResponseWriter) Write(b []byte) (int, error) { + size, err := r.ResponseWriter.Write(b) + r.responseData.size += size + return size, err +} + +func (r *loggingResponseWriter) WriteHeader(statusCode int) { + r.ResponseWriter.WriteHeader(statusCode) +} + +func getFunctionName(temp interface{}) string { + strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".") + return strs[len(strs)-1] +} + +func byteCountToString(b int) string { + const unit = 1000 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp]) +} + +type Middleware func(http.HandlerFunc) http.Handler + +func withLogging(s *statsd.Client) Middleware { + return func(h http.HandlerFunc) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var responseData loggedResponseData + lrw := loggingResponseWriter{ + ResponseWriter: rw, + responseData: &responseData, + } + start := time.Now() + span, ctx := tracer.StartSpanFromContext( + r.Context(), + getFunctionName(h), + tracer.SpanType(ext.SpanTypeSQL), + tracer.ServiceName("hishtory-api"), + ) + defer span.Finish() + + h.ServeHTTP(&lrw, r.WithContext(ctx)) + + duration := time.Since(start) + fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) + if s != nil { + s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0) + s.Incr("hishtory.request", []string{}, 1.0) + } + }) + } +} diff --git a/backend/server/server.go b/backend/server/server.go index 4575350..4cc8377 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -4,35 +4,27 @@ import ( "context" "encoding/json" "fmt" - "html" "io" "log" "math" "net/http" "os" - "reflect" "runtime" "strconv" "strings" "time" - pprofhttp "net/http/pprof" - "github.com/DataDog/datadog-go/statsd" "github.com/ddworken/hishtory/internal/database" "github.com/ddworken/hishtory/shared" _ "github.com/lib/pq" - "github.com/rodaine/table" - httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/DataDog/dd-trace-go.v1/profiler" "gorm.io/gorm" "gorm.io/gorm/logger" ) const ( - PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable" + PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable" + StatsdSocket = "unix:///var/run/datadog/dsd.socket" ) var ( @@ -53,195 +45,6 @@ func getHishtoryVersion(r *http.Request) string { return r.Header.Get("X-Hishtory-Version") } -func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) error { - var usageData []shared.UsageData - usageData, err := GLOBAL_DB.UsageDataFindByUserAndDevice(r.Context(), userId, deviceId) - if err != nil { - return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) - } - if len(usageData) == 0 { - err := GLOBAL_DB.CreateUsageData( - r.Context(), - &shared.UsageData{ - UserId: userId, - DeviceId: deviceId, - LastUsed: time.Now(), - NumEntriesHandled: numEntriesHandled, - Version: getHishtoryVersion(r), - }, - ) - if err != nil { - return fmt.Errorf("db.CreateUsageData: %w", err) - } - } else { - usage := usageData[0] - - if err := GLOBAL_DB.UpdateUsageData(r.Context(), userId, deviceId, time.Now(), getRemoteAddr(r)); err != nil { - return fmt.Errorf("db.UpdateUsageData: %w", err) - } - if numEntriesHandled > 0 { - if err := GLOBAL_DB.UpdateUsageDataForNumEntriesHandled(r.Context(), userId, deviceId, numEntriesHandled); err != nil { - return fmt.Errorf("db.UpdateUsageDataForNumEntriesHandled: %w", err) - } - } - if usage.Version != getHishtoryVersion(r) { - if err := GLOBAL_DB.UpdateUsageDataClientVersion(r.Context(), userId, deviceId, getHishtoryVersion(r)); err != nil { - return fmt.Errorf("db.UpdateUsageDataClientVersion: %w", err) - } - } - } - if isQuery { - if err := GLOBAL_DB.UpdateUsageDataNumberQueries(r.Context(), userId, deviceId); err != nil { - return fmt.Errorf("db.UpdateUsageDataNumberQueries: %w", err) - } - } - - return nil -} - -func usageStatsHandler(w http.ResponseWriter, r *http.Request) { - usageData, err := GLOBAL_DB.UsageDataStats(r.Context()) - if err != nil { - panic(fmt.Errorf("db.UsageDataStats: %w", err)) - } - - tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") - tbl.WithWriter(w) - for _, data := range usageData { - versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") - lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") - tbl.AddRow( - data.RegistrationDate.Format(shared.DateOnly), - data.NumDevices, - data.NumEntries, - data.NumQueries, - data.LastUsedDate.Format(shared.DateOnly), - lastQueryStr, - versions, - data.IpAddresses, - ) - } - tbl.Print() -} - -func statsHandler(w http.ResponseWriter, r *http.Request) { - numDevices, err := GLOBAL_DB.CountAllDevices(r.Context()) - checkGormError(err, 0) - - numEntriesProcessed, err := GLOBAL_DB.UsageDataTotal(r.Context()) - checkGormError(err, 0) - - numDbEntries, err := GLOBAL_DB.CountHistoryEntries(r.Context()) - checkGormError(err, 0) - - oneWeek := time.Hour * 24 * 7 - weeklyActiveInstalls, err := GLOBAL_DB.CountActiveInstalls(r.Context(), oneWeek) - checkGormError(err, 0) - - weeklyQueryUsers, err := GLOBAL_DB.CountQueryUsers(r.Context(), oneWeek) - checkGormError(err, 0) - - lastRegistration, err := GLOBAL_DB.DateOfLastRegistration(r.Context()) - checkGormError(err, 0) - - _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) - _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) - _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) - _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) - _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) - _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) -} - -func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) - if len(entries) == 0 { - return - } - _ = updateUsageData(r, entries[0].UserId, entries[0].DeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false) - - devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId) - checkGormError(err, 0) - - if len(devices) == 0 { - panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) - } - fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) - - err = GLOBAL_DB.AddHistoryEntriesForAllDevices(r.Context(), devices, entries) - if err != nil { - panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) - } - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - _ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false) - historyEntries, err := GLOBAL_DB.AllHistoryEntriesForUser(r.Context(), userId) - checkGormError(err, 1) - fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } -} - -func apiQueryHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - _ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, true) - - // Delete any entries that match a pending deletion request - deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - for _, request := range deletionRequests { - _, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request) - checkGormError(err, 0) - } - - // Then retrieve - historyEntries, err := GLOBAL_DB.HistoryEntriesForDevice(r.Context(), deviceId, 5) - checkGormError(err, 0) - fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } - - // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids - // blocking the entire response. This does have a potential race condition, but that is fine. - if isProductionEnvironment() { - go func() { - span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") - err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId) - span.Finish(tracer.WithError(err)) - }() - } else { - err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId) - if err != nil { - panic("failed to increment read counts") - } - } - - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Incr("hishtory.query", []string{}, 1.0) - } -} - func getRemoteAddr(r *http.Request) string { addr, ok := r.Header["X-Real-Ip"] if !ok || len(addr) == 0 { @@ -250,191 +53,6 @@ func getRemoteAddr(r *http.Request) string { return addr[0] } -func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { - if getMaximumNumberOfAllowedUsers() < math.MaxInt { - numDistinctUsers, err := GLOBAL_DB.DistinctUsers(r.Context()) - if err != nil { - panic(fmt.Errorf("db.DistinctUsers: %w", err)) - } - if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { - panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) - } - } - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - existingDevicesCount, err := GLOBAL_DB.CountDevicesForUser(r.Context(), userId) - checkGormError(err, 0) - fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) - if err := GLOBAL_DB.CreateDevice(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { - checkGormError(err, 0) - } - - if existingDevicesCount > 0 { - err := GLOBAL_DB.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) - checkGormError(err, 0) - } - _ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false) - - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - var dumpRequests []*shared.DumpRequest - // Filter out ones requested by the hishtory instance that sent this request - dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - - if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - srcDeviceId := getRequiredQueryParam(r, "source_device_id") - requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) - - // sanity check - for _, entry := range entries { - entry.DeviceId = requestingDeviceId - if entry.UserId != userId { - panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) - } - } - - err = GLOBAL_DB.AddHistoryEntries(r.Context(), entries...) - checkGormError(err, 0) - err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) - checkGormError(err, 0) - _ = updateUsageData(r, userId, srcDeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func apiBannerHandler(w http.ResponseWriter, r *http.Request) { - commitHash := getRequiredQueryParam(r, "commit_hash") - deviceId := getRequiredQueryParam(r, "device_id") - forcedBanner := r.URL.Query().Get("forced_banner") - fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) - if getHishtoryVersion(r) == "v0.160" { - w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) - return - } - w.Write([]byte(html.EscapeString(forcedBanner))) -} - -func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // Increment the ReadCount - err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId) - checkGormError(err, 0) - - // Return all the deletion requests - deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var request shared.DeletionRequest - - if err := json.Unmarshal(data, &request); err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - request.ReadCount = 0 - fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) - - err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func healthCheckHandler(w http.ResponseWriter, r *http.Request) { - if isProductionEnvironment() { - encHistoryEntryCount, err := GLOBAL_DB.CountHistoryEntries(r.Context()) - checkGormError(err, 0) - if encHistoryEntryCount < 1000 { - panic("Suspiciously few enc history entries!") - } - - deviceCount, err := GLOBAL_DB.CountAllDevices(r.Context()) - checkGormError(err, 0) - if deviceCount < 100 { - panic("Suspiciously few devices!") - } - // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. - err = GLOBAL_DB.AddHistoryEntries(r.Context(), &shared.EncHistoryEntry{ - EncryptedData: []byte("data"), - Nonce: []byte("nonce"), - DeviceId: "healthcheck_device_id", - UserId: "healthcheck_user_id", - Date: time.Now(), - EncryptedId: "healthcheck_enc_id", - ReadCount: 10000, - }) - checkGormError(err, 0) - } else { - err := GLOBAL_DB.Ping() - if err != nil { - panic(fmt.Errorf("failed to ping DB: %w", err)) - } - } - w.Write([]byte("OK")) -} - -func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { - if r.Host == "api.hishtory.dev" || isProductionEnvironment() { - panic("refusing to wipe the DB for prod") - } - if !isTestEnvironment() { - panic("refusing to wipe the DB non-test environment") - } - - err := GLOBAL_DB.Unsafe_DeleteAllHistoryEntries(r.Context()) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { - stats, err := GLOBAL_DB.Stats() - if err != nil { - panic(err) - } - - _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) -} - func isTestEnvironment() bool { return os.Getenv("HISHTORY_TEST") != "" } @@ -540,16 +158,6 @@ func runBackgroundJobs(ctx context.Context) { } } -func triggerCronHandler(w http.ResponseWriter, r *http.Request) { - err := cron(r.Context()) - if err != nil { - panic(err) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - type releaseInfo struct { Name string `json:"name"` } @@ -674,200 +282,22 @@ func buildUpdateInfo(version string) shared.UpdateInfo { } } -func apiDownloadHandler(w http.ResponseWriter, r *http.Request) { - updateInfo := buildUpdateInfo(ReleaseVersion) - resp, err := json.Marshal(updateInfo) - if err != nil { - panic(err) - } - w.Write(resp) -} - -func slsaStatusHandler(w http.ResponseWriter, r *http.Request) { - // returns "OK" unless there is a current SLSA bug - v := getHishtoryVersion(r) - if !strings.Contains(v, "v0.") { - w.Write([]byte("OK")) - return - } - vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) - if err != nil { - w.Write([]byte("OK")) - return - } - if vNum < 159 { - w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) - return - } - w.Write([]byte("OK")) -} - -func feedbackHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var feedback shared.Feedback - err = json.Unmarshal(data, &feedback) - if err != nil { - panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) - } - fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) - err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback) - checkGormError(err, 0) - - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -type loggedResponseData struct { - size int -} - -type loggingResponseWriter struct { - http.ResponseWriter - responseData *loggedResponseData -} - -func (r *loggingResponseWriter) Write(b []byte) (int, error) { - size, err := r.ResponseWriter.Write(b) - r.responseData.size += size - return size, err -} - -func (r *loggingResponseWriter) WriteHeader(statusCode int) { - r.ResponseWriter.WriteHeader(statusCode) -} - -func getFunctionName(temp interface{}) string { - strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".") - return strs[len(strs)-1] -} - -func withLogging(h http.HandlerFunc) http.Handler { - logFn := func(rw http.ResponseWriter, r *http.Request) { - var responseData loggedResponseData - lrw := loggingResponseWriter{ - ResponseWriter: rw, - responseData: &responseData, - } - start := time.Now() - span, ctx := tracer.StartSpanFromContext( - r.Context(), - getFunctionName(h), - tracer.SpanType(ext.SpanTypeSQL), - tracer.ServiceName("hishtory-api"), - ) - defer span.Finish() - - h(&lrw, r.WithContext(ctx)) - - duration := time.Since(start) - fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0) - GLOBAL_STATSD.Incr("hishtory.request", []string{}, 1.0) - } - } - return http.HandlerFunc(logFn) -} - -func byteCountToString(b int) string { - const unit = 1000 - if b < unit { - return fmt.Sprintf("%d B", b) - } - div, exp := int64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp]) -} - -func configureObservability(mux *httptrace.ServeMux) func() { - // Profiler - err := profiler.Start( - profiler.WithService("hishtory-api"), - profiler.WithVersion(ReleaseVersion), - profiler.WithAPIKey(os.Getenv("DD_API_KEY")), - profiler.WithUDS("/var/run/datadog/apm.socket"), - profiler.WithProfileTypes( - profiler.CPUProfile, - profiler.HeapProfile, - ), - ) - if err != nil { - fmt.Printf("Failed to start DataDog profiler: %v\n", err) - } - // Tracer - tracer.Start( - tracer.WithRuntimeMetrics(), - tracer.WithService("hishtory-api"), - tracer.WithUDS("/var/run/datadog/apm.socket"), - ) - defer tracer.Stop() - // Stats - ddStats, err := statsd.New("unix:///var/run/datadog/dsd.socket") +func main() { + s, err := statsd.New(StatsdSocket) if err != nil { fmt.Printf("Failed to start DataDog statsd: %v\n", err) } - GLOBAL_STATSD = ddStats - // Pprof - mux.HandleFunc("/debug/pprof/", pprofhttp.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) - // Func to stop all of the above - return func() { - profiler.Stop() - tracer.Stop() + // TODO: remove this global once we have a better way to pass it around + GLOBAL_STATSD = s + + srv := NewServer(GLOBAL_DB, WithStatsd(s)) + + if err := srv.Run(context.Background(), ":8080"); err != nil { + panic(err) } } -func main() { - mux := httptrace.NewServeMux() - - if isProductionEnvironment() { - defer configureObservability(mux)() - go func() { - if err := GLOBAL_DB.DeepClean(context.Background()); err != nil { - panic(err) - } - }() - } - - mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler)) - mux.Handle("/api/v1/get-dump-requests", withLogging(apiGetPendingDumpRequestsHandler)) - mux.Handle("/api/v1/submit-dump", withLogging(apiSubmitDumpHandler)) - mux.Handle("/api/v1/query", withLogging(apiQueryHandler)) - mux.Handle("/api/v1/bootstrap", withLogging(apiBootstrapHandler)) - mux.Handle("/api/v1/register", withLogging(apiRegisterHandler)) - mux.Handle("/api/v1/banner", withLogging(apiBannerHandler)) - mux.Handle("/api/v1/download", withLogging(apiDownloadHandler)) - mux.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler)) - mux.Handle("/api/v1/get-deletion-requests", withLogging(getDeletionRequestsHandler)) - mux.Handle("/api/v1/add-deletion-request", withLogging(addDeletionRequestHandler)) - mux.Handle("/api/v1/slsa-status", withLogging(slsaStatusHandler)) - mux.Handle("/api/v1/feedback", withLogging(feedbackHandler)) - mux.Handle("/healthcheck", withLogging(healthCheckHandler)) - mux.Handle("/internal/api/v1/usage-stats", withLogging(usageStatsHandler)) - mux.Handle("/internal/api/v1/stats", withLogging(statsHandler)) - if isTestEnvironment() { - mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler)) - mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler)) - } - - fmt.Println("Listening on localhost:8080") - log.Fatal(http.ListenAndServe(":8080", mux)) -} - func checkGormResult(result *gorm.DB) { checkGormError(result.Error, 1) } diff --git a/backend/server/server_test.go b/backend/server/server_test.go index a027216..9187409 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -24,6 +24,7 @@ import ( func TestESubmitThenQuery(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) // Register a few devices userId := data.UserId("key") @@ -32,11 +33,11 @@ func TestESubmitThenQuery(t *testing.T) { otherUser := data.UserId("otherkey") otherDev := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Submit a few entries for different devices entry := testutils.MakeFakeHistoryEntry("ls ~/") @@ -45,12 +46,12 @@ func TestESubmitThenQuery(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -79,7 +80,7 @@ func TestESubmitThenQuery(t *testing.T) { // Same for device id 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -107,7 +108,7 @@ func TestESubmitThenQuery(t *testing.T) { // Bootstrap handler should return 2 entries, one for each device w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil) - apiBootstrapHandler(w, searchReq) + s.apiBootstrapHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -124,6 +125,7 @@ func TestESubmitThenQuery(t *testing.T) { func TestDumpRequestAndResponse(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) // Register a first device for two different users userId := data.UserId("dkey") @@ -133,17 +135,17 @@ func TestDumpRequestAndResponse(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Query for dump requests, there should be one for userId w := httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -163,7 +165,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And one for otherUser w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -183,7 +185,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none if we query for a user ID that doesn't exit w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -193,7 +195,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none for a missing user ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -211,11 +213,11 @@ func TestDumpRequestAndResponse(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody)) - apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) // Check that the dump request is no longer there for userId for either device ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -226,7 +228,7 @@ func TestDumpRequestAndResponse(t *testing.T) { w = httptest.NewRecorder() // The other user - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -236,7 +238,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // But it is there for the other user w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -257,7 +259,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And finally, query to ensure that the dumped entries are in the DB w = httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -323,6 +325,7 @@ func TestUpdateReleaseVersion(t *testing.T) { func TestDeletionRequests(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) // Register two devices for two different users userId := data.UserId("dkey") @@ -332,13 +335,13 @@ func TestDeletionRequests(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Add an entry for user1 entry1 := testutils.MakeFakeHistoryEntry("ls ~/") @@ -348,7 +351,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // And another entry for user1 entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -358,7 +361,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // And an entry for user2 that has the same timestamp as the previous entry entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -369,12 +372,12 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -413,13 +416,13 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal(delReq) testutils.Check(t, err) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - addDeletionRequestHandler(httptest.NewRecorder(), req) + s.addDeletionRequestHandler(httptest.NewRecorder(), req) // Query again for device id 1 and get a single result time.Sleep(10 * time.Millisecond) w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -447,7 +450,7 @@ func TestDeletionRequests(t *testing.T) { // Query for user 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -475,7 +478,7 @@ func TestDeletionRequests(t *testing.T) { // Query for deletion requests w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - getDeletionRequestsHandler(w, searchReq) + s.getDeletionRequestsHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -504,8 +507,9 @@ func TestDeletionRequests(t *testing.T) { } func TestHealthcheck(t *testing.T) { + s := NewServer(GLOBAL_DB) w := httptest.NewRecorder() - healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) + s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { t.Fatalf("expected 200 resp code for healthCheckHandler") } @@ -524,6 +528,7 @@ func TestHealthcheck(t *testing.T) { func TestLimitRegistrations(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices")) defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() @@ -531,28 +536,29 @@ func TestLimitRegistrations(t *testing.T) { // Register three devices across two users deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // And this next one should fail since it is a new user defer func() { _ = recover() }() deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) t.Errorf("expected panic") } func TestCleanDatabaseNoErrors(t *testing.T) { // Init InitDB() + s := NewServer(GLOBAL_DB) // Create a user and an entry userId := data.UserId("dkey") devId1 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) entry1 := testutils.MakeFakeHistoryEntry("ls ~/") entry1.DeviceId = devId1 encEntry, err := data.EncryptHistoryEntry("dkey", entry1) @@ -560,7 +566,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics testutils.Check(t, GLOBAL_DB.Clean(context.TODO())) diff --git a/backend/server/srv.go b/backend/server/srv.go new file mode 100644 index 0000000..c373167 --- /dev/null +++ b/backend/server/srv.go @@ -0,0 +1,611 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "math" + "net/http" + pprofhttp "net/http/pprof" + "os" + "strconv" + "strings" + "time" + + "github.com/DataDog/datadog-go/statsd" + "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/shared" + "github.com/rodaine/table" + httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/profiler" +) + +type Srv struct { + db *database.DB + statsd *statsd.Client +} + +type ServerOption func(*Srv) + +func WithStatsd(statsd *statsd.Client) ServerOption { + return func(s *Srv) { + s.statsd = statsd + } +} + +func NewServer(db *database.DB, options ...ServerOption) *Srv { + srv := Srv{db: db} + for _, option := range options { + option(&srv) + } + return &srv +} + +func (s *Srv) Run(ctx context.Context, addr string) error { + mux := httptrace.NewServeMux() + + if isProductionEnvironment() { + defer configureObservability(mux)() + go func() { + if err := s.db.DeepClean(ctx); err != nil { + panic(err) + } + }() + } + loggerMiddleware := withLogging(s.statsd) + + mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) + mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) + mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) + mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) + mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) + mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) + mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) + mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) + mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) + mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) + mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) + mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) + mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) + mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) + mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) + mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) + if isTestEnvironment() { + mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) + mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) + } + + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + } + + fmt.Printf("Listening on %s\n", addr) + if err := httpServer.ListenAndServe(); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("http.ListenAndServe: %w", err) + } + } + + return nil +} + +func (s *Srv) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) + if len(entries) == 0 { + return + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId) + checkGormError(err, 0) + + if len(devices) == 0 { + panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) + } + fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) + + err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) + if err != nil { + panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) + } + if s.statsd != nil { + s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) + checkGormError(err, 1) + fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } +} + +func (s *Srv) apiQueryHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + // Delete any entries that match a pending deletion request + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + for _, request := range deletionRequests { + _, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request) + checkGormError(err, 0) + } + + // Then retrieve + historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) + checkGormError(err, 0) + fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } + + // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids + // blocking the entire response. This does have a potential race condition, but that is fine. + if isProductionEnvironment() { + go func() { + span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + span.Finish(tracer.WithError(err)) + }() + } else { + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + if err != nil { + panic("failed to increment read counts") + } + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.query", []string{}, 1.0) + } +} + +func (s *Srv) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + srcDeviceId := getRequiredQueryParam(r, "source_device_id") + requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) + + // sanity check + for _, entry := range entries { + entry.DeviceId = requestingDeviceId + if entry.UserId != userId { + panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) + } + } + + err = s.db.EncHistoryCreateMulti(r.Context(), entries...) + checkGormError(err, 0) + err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) + checkGormError(err, 0) + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) apiBannerHandler(w http.ResponseWriter, r *http.Request) { + commitHash := getRequiredQueryParam(r, "commit_hash") + deviceId := getRequiredQueryParam(r, "device_id") + forcedBanner := r.URL.Query().Get("forced_banner") + fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) + if getHishtoryVersion(r) == "v0.160" { + w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) + return + } + w.Write([]byte(html.EscapeString(forcedBanner))) +} + +func (s *Srv) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + var dumpRequests []*shared.DumpRequest + // Filter out ones requested by the hishtory instance that sent this request + dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + + if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Srv) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // Increment the ReadCount + err := s.db.DeletionRequestInc(r.Context(), userId, deviceId) + checkGormError(err, 0) + + // Return all the deletion requests + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Srv) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var request shared.DeletionRequest + + if err := json.Unmarshal(data, &request); err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + request.ReadCount = 0 + fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) + + err = s.db.DeletionRequestCreate(r.Context(), &request) + checkGormError(err, 0) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (_ *Srv) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { + updateInfo := buildUpdateInfo(ReleaseVersion) + resp, err := json.Marshal(updateInfo) + if err != nil { + panic(err) + } + w.Write(resp) +} + +func (s *Srv) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { + if getMaximumNumberOfAllowedUsers() < math.MaxInt { + numDistinctUsers, err := s.db.DistinctUsers(r.Context()) + if err != nil { + panic(fmt.Errorf("db.DistinctUsers: %w", err)) + } + if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { + panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) + } + } + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) + checkGormError(err, 0) + fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) + if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { + checkGormError(err, 0) + } + + if existingDevicesCount > 0 { + err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) + checkGormError(err, 0) + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.register", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) triggerCronHandler(w http.ResponseWriter, r *http.Request) { + err := cron(r.Context()) + if err != nil { + panic(err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { + // returns "OK" unless there is a current SLSA bug + v := getHishtoryVersion(r) + if !strings.Contains(v, "v0.") { + w.Write([]byte("OK")) + return + } + vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) + if err != nil { + w.Write([]byte("OK")) + return + } + if vNum < 159 { + w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) + return + } + w.Write([]byte("OK")) +} + +func (s *Srv) feedbackHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var feedback shared.Feedback + err = json.Unmarshal(data, &feedback) + if err != nil { + panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) + } + fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) + err = s.db.FeedbackCreate(r.Context(), &feedback) + checkGormError(err, 0) + + if s.statsd != nil { + s.statsd.Incr("hishtory.uninstall", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) healthCheckHandler(w http.ResponseWriter, r *http.Request) { + if isProductionEnvironment() { + // Check that we have a reasonable looking set of devices/entries in the DB + //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() + //if err != nil { + // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) + //} + //defer rows.Close() + //if !rows.Next() { + // panic("Suspiciously few enc history entries!") + //} + encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err, 0) + if encHistoryEntryCount < 1000 { + panic("Suspiciously few enc history entries!") + } + + deviceCount, err := s.db.DevicesCount(r.Context()) + checkGormError(err, 0) + if deviceCount < 100 { + panic("Suspiciously few devices!") + } + // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. + err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ + EncryptedData: []byte("data"), + Nonce: []byte("nonce"), + DeviceId: "healthcheck_device_id", + UserId: "healthcheck_user_id", + Date: time.Now(), + EncryptedId: "healthcheck_enc_id", + ReadCount: 10000, + }) + checkGormError(err, 0) + } else { + err := s.db.Ping() + if err != nil { + panic(fmt.Errorf("failed to ping DB: %w", err)) + } + } + w.Write([]byte("OK")) +} + +func (s *Srv) usageStatsHandler(w http.ResponseWriter, r *http.Request) { + usageData, err := s.db.UsageDataStats(r.Context()) + if err != nil { + panic(fmt.Errorf("db.UsageDataStats: %w", err)) + } + + tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") + tbl.WithWriter(w) + for _, data := range usageData { + versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") + lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") + tbl.AddRow( + data.RegistrationDate.Format(shared.DateOnly), + data.NumDevices, + data.NumEntries, + data.NumQueries, + data.LastUsedDate.Format(shared.DateOnly), + lastQueryStr, + versions, + data.IpAddresses, + ) + } + tbl.Print() +} + +func (s *Srv) statsHandler(w http.ResponseWriter, r *http.Request) { + numDevices, err := s.db.DevicesCount(r.Context()) + checkGormError(err, 0) + + numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) + checkGormError(err, 0) + + numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err, 0) + + oneWeek := time.Hour * 24 * 7 + weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) + checkGormError(err, 0) + + weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) + checkGormError(err, 0) + + lastRegistration, err := s.db.LastRegistration(r.Context()) + checkGormError(err, 0) + + _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) + _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) + _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) + _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) + _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) + _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) +} + +func (s *Srv) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { + if r.Host == "api.hishtory.dev" || isProductionEnvironment() { + panic("refusing to wipe the DB for prod") + } + if !isTestEnvironment() { + panic("refusing to wipe the DB non-test environment") + } + + err := s.db.EncHistoryClear(r.Context()) + checkGormError(err, 0) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { + stats, err := s.db.Stats() + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) +} + +func (s *Srv) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error { + var usageData []shared.UsageData + usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId) + if err != nil { + return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) + } + if len(usageData) == 0 { + err := s.db.UsageDataCreate( + ctx, + &shared.UsageData{ + UserId: userId, + DeviceId: deviceId, + LastUsed: time.Now(), + NumEntriesHandled: numEntriesHandled, + Version: version, + }, + ) + if err != nil { + return fmt.Errorf("db.UsageDataCreate: %w", err) + } + } else { + usage := usageData[0] + + if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { + return fmt.Errorf("db.UsageDataUpdate: %w", err) + } + if numEntriesHandled > 0 { + if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) + } + } + if usage.Version != version { + if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { + return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) + } + } + } + if isQuery { + if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) + } + } + + return nil +} + +func configureObservability(mux *httptrace.ServeMux) func() { + // Profiler + err := profiler.Start( + profiler.WithService("hishtory-api"), + profiler.WithVersion(ReleaseVersion), + profiler.WithAPIKey(os.Getenv("DD_API_KEY")), + profiler.WithUDS("/var/run/datadog/apm.socket"), + profiler.WithProfileTypes( + profiler.CPUProfile, + profiler.HeapProfile, + ), + ) + if err != nil { + fmt.Printf("Failed to start DataDog profiler: %v\n", err) + } + // Tracer + tracer.Start( + tracer.WithRuntimeMetrics(), + tracer.WithService("hishtory-api"), + tracer.WithUDS("/var/run/datadog/apm.socket"), + ) + // TODO: should this be here? + defer tracer.Stop() + + // Pprof + mux.HandleFunc("/debug/pprof/", pprofhttp.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) + + // Func to stop all of the above + return func() { + profiler.Stop() + tracer.Stop() + } +}