From 02b1e8287d1efcb6d04d7f7233199a70ebcba06e Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 09:26:20 -0400 Subject: [PATCH 1/6] isolate all server handlers into a single struct, without using global variables --- backend/server/middleware.go | 81 +++++ backend/server/server.go | 592 +------------------------------- backend/server/server_test.go | 84 ++--- backend/server/srv.go | 611 ++++++++++++++++++++++++++++++++++ 4 files changed, 748 insertions(+), 620 deletions(-) create mode 100644 backend/server/middleware.go create mode 100644 backend/server/srv.go diff --git a/backend/server/middleware.go b/backend/server/middleware.go new file mode 100644 index 0000000..5d8d7e7 --- /dev/null +++ b/backend/server/middleware.go @@ -0,0 +1,81 @@ +package main + +import ( + "fmt" + "github.com/DataDog/datadog-go/statsd" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "net/http" + "reflect" + "runtime" + "strings" + "time" +) + +type loggedResponseData struct { + size int +} + +type loggingResponseWriter struct { + http.ResponseWriter + responseData *loggedResponseData +} + +func (r *loggingResponseWriter) Write(b []byte) (int, error) { + size, err := r.ResponseWriter.Write(b) + r.responseData.size += size + return size, err +} + +func (r *loggingResponseWriter) WriteHeader(statusCode int) { + r.ResponseWriter.WriteHeader(statusCode) +} + +func getFunctionName(temp interface{}) string { + strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".") + return strs[len(strs)-1] +} + +func byteCountToString(b int) string { + const unit = 1000 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := int64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp]) +} + +type Middleware func(http.HandlerFunc) http.Handler + +func withLogging(s *statsd.Client) Middleware { + return func(h http.HandlerFunc) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var responseData loggedResponseData + lrw := loggingResponseWriter{ + ResponseWriter: rw, + responseData: &responseData, + } + start := time.Now() + span, ctx := tracer.StartSpanFromContext( + r.Context(), + getFunctionName(h), + tracer.SpanType(ext.SpanTypeSQL), + tracer.ServiceName("hishtory-api"), + ) + defer span.Finish() + + h.ServeHTTP(&lrw, r.WithContext(ctx)) + + duration := time.Since(start) + fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) + if s != nil { + s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0) + s.Incr("hishtory.request", []string{}, 1.0) + } + }) + } +} diff --git a/backend/server/server.go b/backend/server/server.go index 4575350..4cc8377 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -4,35 +4,27 @@ import ( "context" "encoding/json" "fmt" - "html" "io" "log" "math" "net/http" "os" - "reflect" "runtime" "strconv" "strings" "time" - pprofhttp "net/http/pprof" - "github.com/DataDog/datadog-go/statsd" "github.com/ddworken/hishtory/internal/database" "github.com/ddworken/hishtory/shared" _ "github.com/lib/pq" - "github.com/rodaine/table" - httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/DataDog/dd-trace-go.v1/profiler" "gorm.io/gorm" "gorm.io/gorm/logger" ) const ( - PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable" + PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable" + StatsdSocket = "unix:///var/run/datadog/dsd.socket" ) var ( @@ -53,195 +45,6 @@ func getHishtoryVersion(r *http.Request) string { return r.Header.Get("X-Hishtory-Version") } -func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) error { - var usageData []shared.UsageData - usageData, err := GLOBAL_DB.UsageDataFindByUserAndDevice(r.Context(), userId, deviceId) - if err != nil { - return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) - } - if len(usageData) == 0 { - err := GLOBAL_DB.CreateUsageData( - r.Context(), - &shared.UsageData{ - UserId: userId, - DeviceId: deviceId, - LastUsed: time.Now(), - NumEntriesHandled: numEntriesHandled, - Version: getHishtoryVersion(r), - }, - ) - if err != nil { - return fmt.Errorf("db.CreateUsageData: %w", err) - } - } else { - usage := usageData[0] - - if err := GLOBAL_DB.UpdateUsageData(r.Context(), userId, deviceId, time.Now(), getRemoteAddr(r)); err != nil { - return fmt.Errorf("db.UpdateUsageData: %w", err) - } - if numEntriesHandled > 0 { - if err := GLOBAL_DB.UpdateUsageDataForNumEntriesHandled(r.Context(), userId, deviceId, numEntriesHandled); err != nil { - return fmt.Errorf("db.UpdateUsageDataForNumEntriesHandled: %w", err) - } - } - if usage.Version != getHishtoryVersion(r) { - if err := GLOBAL_DB.UpdateUsageDataClientVersion(r.Context(), userId, deviceId, getHishtoryVersion(r)); err != nil { - return fmt.Errorf("db.UpdateUsageDataClientVersion: %w", err) - } - } - } - if isQuery { - if err := GLOBAL_DB.UpdateUsageDataNumberQueries(r.Context(), userId, deviceId); err != nil { - return fmt.Errorf("db.UpdateUsageDataNumberQueries: %w", err) - } - } - - return nil -} - -func usageStatsHandler(w http.ResponseWriter, r *http.Request) { - usageData, err := GLOBAL_DB.UsageDataStats(r.Context()) - if err != nil { - panic(fmt.Errorf("db.UsageDataStats: %w", err)) - } - - tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") - tbl.WithWriter(w) - for _, data := range usageData { - versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") - lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") - tbl.AddRow( - data.RegistrationDate.Format(shared.DateOnly), - data.NumDevices, - data.NumEntries, - data.NumQueries, - data.LastUsedDate.Format(shared.DateOnly), - lastQueryStr, - versions, - data.IpAddresses, - ) - } - tbl.Print() -} - -func statsHandler(w http.ResponseWriter, r *http.Request) { - numDevices, err := GLOBAL_DB.CountAllDevices(r.Context()) - checkGormError(err, 0) - - numEntriesProcessed, err := GLOBAL_DB.UsageDataTotal(r.Context()) - checkGormError(err, 0) - - numDbEntries, err := GLOBAL_DB.CountHistoryEntries(r.Context()) - checkGormError(err, 0) - - oneWeek := time.Hour * 24 * 7 - weeklyActiveInstalls, err := GLOBAL_DB.CountActiveInstalls(r.Context(), oneWeek) - checkGormError(err, 0) - - weeklyQueryUsers, err := GLOBAL_DB.CountQueryUsers(r.Context(), oneWeek) - checkGormError(err, 0) - - lastRegistration, err := GLOBAL_DB.DateOfLastRegistration(r.Context()) - checkGormError(err, 0) - - _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) - _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) - _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) - _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) - _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) - _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) -} - -func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) - if len(entries) == 0 { - return - } - _ = updateUsageData(r, entries[0].UserId, entries[0].DeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false) - - devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId) - checkGormError(err, 0) - - if len(devices) == 0 { - panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) - } - fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) - - err = GLOBAL_DB.AddHistoryEntriesForAllDevices(r.Context(), devices, entries) - if err != nil { - panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) - } - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - _ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false) - historyEntries, err := GLOBAL_DB.AllHistoryEntriesForUser(r.Context(), userId) - checkGormError(err, 1) - fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } -} - -func apiQueryHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - _ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, true) - - // Delete any entries that match a pending deletion request - deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - for _, request := range deletionRequests { - _, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request) - checkGormError(err, 0) - } - - // Then retrieve - historyEntries, err := GLOBAL_DB.HistoryEntriesForDevice(r.Context(), deviceId, 5) - checkGormError(err, 0) - fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } - - // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids - // blocking the entire response. This does have a potential race condition, but that is fine. - if isProductionEnvironment() { - go func() { - span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") - err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId) - span.Finish(tracer.WithError(err)) - }() - } else { - err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId) - if err != nil { - panic("failed to increment read counts") - } - } - - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Incr("hishtory.query", []string{}, 1.0) - } -} - func getRemoteAddr(r *http.Request) string { addr, ok := r.Header["X-Real-Ip"] if !ok || len(addr) == 0 { @@ -250,191 +53,6 @@ func getRemoteAddr(r *http.Request) string { return addr[0] } -func apiRegisterHandler(w http.ResponseWriter, r *http.Request) { - if getMaximumNumberOfAllowedUsers() < math.MaxInt { - numDistinctUsers, err := GLOBAL_DB.DistinctUsers(r.Context()) - if err != nil { - panic(fmt.Errorf("db.DistinctUsers: %w", err)) - } - if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { - panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) - } - } - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - existingDevicesCount, err := GLOBAL_DB.CountDevicesForUser(r.Context(), userId) - checkGormError(err, 0) - fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) - if err := GLOBAL_DB.CreateDevice(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { - checkGormError(err, 0) - } - - if existingDevicesCount > 0 { - err := GLOBAL_DB.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) - checkGormError(err, 0) - } - _ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false) - - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - var dumpRequests []*shared.DumpRequest - // Filter out ones requested by the hishtory instance that sent this request - dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - - if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - srcDeviceId := getRequiredQueryParam(r, "source_device_id") - requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) - - // sanity check - for _, entry := range entries { - entry.DeviceId = requestingDeviceId - if entry.UserId != userId { - panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) - } - } - - err = GLOBAL_DB.AddHistoryEntries(r.Context(), entries...) - checkGormError(err, 0) - err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) - checkGormError(err, 0) - _ = updateUsageData(r, userId, srcDeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func apiBannerHandler(w http.ResponseWriter, r *http.Request) { - commitHash := getRequiredQueryParam(r, "commit_hash") - deviceId := getRequiredQueryParam(r, "device_id") - forcedBanner := r.URL.Query().Get("forced_banner") - fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) - if getHishtoryVersion(r) == "v0.160" { - w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) - return - } - w.Write([]byte(html.EscapeString(forcedBanner))) -} - -func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // Increment the ReadCount - err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId) - checkGormError(err, 0) - - // Return all the deletion requests - deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var request shared.DeletionRequest - - if err := json.Unmarshal(data, &request); err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - request.ReadCount = 0 - fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) - - err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func healthCheckHandler(w http.ResponseWriter, r *http.Request) { - if isProductionEnvironment() { - encHistoryEntryCount, err := GLOBAL_DB.CountHistoryEntries(r.Context()) - checkGormError(err, 0) - if encHistoryEntryCount < 1000 { - panic("Suspiciously few enc history entries!") - } - - deviceCount, err := GLOBAL_DB.CountAllDevices(r.Context()) - checkGormError(err, 0) - if deviceCount < 100 { - panic("Suspiciously few devices!") - } - // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. - err = GLOBAL_DB.AddHistoryEntries(r.Context(), &shared.EncHistoryEntry{ - EncryptedData: []byte("data"), - Nonce: []byte("nonce"), - DeviceId: "healthcheck_device_id", - UserId: "healthcheck_user_id", - Date: time.Now(), - EncryptedId: "healthcheck_enc_id", - ReadCount: 10000, - }) - checkGormError(err, 0) - } else { - err := GLOBAL_DB.Ping() - if err != nil { - panic(fmt.Errorf("failed to ping DB: %w", err)) - } - } - w.Write([]byte("OK")) -} - -func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { - if r.Host == "api.hishtory.dev" || isProductionEnvironment() { - panic("refusing to wipe the DB for prod") - } - if !isTestEnvironment() { - panic("refusing to wipe the DB non-test environment") - } - - err := GLOBAL_DB.Unsafe_DeleteAllHistoryEntries(r.Context()) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { - stats, err := GLOBAL_DB.Stats() - if err != nil { - panic(err) - } - - _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) -} - func isTestEnvironment() bool { return os.Getenv("HISHTORY_TEST") != "" } @@ -540,16 +158,6 @@ func runBackgroundJobs(ctx context.Context) { } } -func triggerCronHandler(w http.ResponseWriter, r *http.Request) { - err := cron(r.Context()) - if err != nil { - panic(err) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - type releaseInfo struct { Name string `json:"name"` } @@ -674,200 +282,22 @@ func buildUpdateInfo(version string) shared.UpdateInfo { } } -func apiDownloadHandler(w http.ResponseWriter, r *http.Request) { - updateInfo := buildUpdateInfo(ReleaseVersion) - resp, err := json.Marshal(updateInfo) - if err != nil { - panic(err) - } - w.Write(resp) -} - -func slsaStatusHandler(w http.ResponseWriter, r *http.Request) { - // returns "OK" unless there is a current SLSA bug - v := getHishtoryVersion(r) - if !strings.Contains(v, "v0.") { - w.Write([]byte("OK")) - return - } - vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) - if err != nil { - w.Write([]byte("OK")) - return - } - if vNum < 159 { - w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) - return - } - w.Write([]byte("OK")) -} - -func feedbackHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var feedback shared.Feedback - err = json.Unmarshal(data, &feedback) - if err != nil { - panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) - } - fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) - err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback) - checkGormError(err, 0) - - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -type loggedResponseData struct { - size int -} - -type loggingResponseWriter struct { - http.ResponseWriter - responseData *loggedResponseData -} - -func (r *loggingResponseWriter) Write(b []byte) (int, error) { - size, err := r.ResponseWriter.Write(b) - r.responseData.size += size - return size, err -} - -func (r *loggingResponseWriter) WriteHeader(statusCode int) { - r.ResponseWriter.WriteHeader(statusCode) -} - -func getFunctionName(temp interface{}) string { - strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".") - return strs[len(strs)-1] -} - -func withLogging(h http.HandlerFunc) http.Handler { - logFn := func(rw http.ResponseWriter, r *http.Request) { - var responseData loggedResponseData - lrw := loggingResponseWriter{ - ResponseWriter: rw, - responseData: &responseData, - } - start := time.Now() - span, ctx := tracer.StartSpanFromContext( - r.Context(), - getFunctionName(h), - tracer.SpanType(ext.SpanTypeSQL), - tracer.ServiceName("hishtory-api"), - ) - defer span.Finish() - - h(&lrw, r.WithContext(ctx)) - - duration := time.Since(start) - fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size)) - if GLOBAL_STATSD != nil { - GLOBAL_STATSD.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0) - GLOBAL_STATSD.Incr("hishtory.request", []string{}, 1.0) - } - } - return http.HandlerFunc(logFn) -} - -func byteCountToString(b int) string { - const unit = 1000 - if b < unit { - return fmt.Sprintf("%d B", b) - } - div, exp := int64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp]) -} - -func configureObservability(mux *httptrace.ServeMux) func() { - // Profiler - err := profiler.Start( - profiler.WithService("hishtory-api"), - profiler.WithVersion(ReleaseVersion), - profiler.WithAPIKey(os.Getenv("DD_API_KEY")), - profiler.WithUDS("/var/run/datadog/apm.socket"), - profiler.WithProfileTypes( - profiler.CPUProfile, - profiler.HeapProfile, - ), - ) - if err != nil { - fmt.Printf("Failed to start DataDog profiler: %v\n", err) - } - // Tracer - tracer.Start( - tracer.WithRuntimeMetrics(), - tracer.WithService("hishtory-api"), - tracer.WithUDS("/var/run/datadog/apm.socket"), - ) - defer tracer.Stop() - // Stats - ddStats, err := statsd.New("unix:///var/run/datadog/dsd.socket") +func main() { + s, err := statsd.New(StatsdSocket) if err != nil { fmt.Printf("Failed to start DataDog statsd: %v\n", err) } - GLOBAL_STATSD = ddStats - // Pprof - mux.HandleFunc("/debug/pprof/", pprofhttp.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) - // Func to stop all of the above - return func() { - profiler.Stop() - tracer.Stop() + // TODO: remove this global once we have a better way to pass it around + GLOBAL_STATSD = s + + srv := NewServer(GLOBAL_DB, WithStatsd(s)) + + if err := srv.Run(context.Background(), ":8080"); err != nil { + panic(err) } } -func main() { - mux := httptrace.NewServeMux() - - if isProductionEnvironment() { - defer configureObservability(mux)() - go func() { - if err := GLOBAL_DB.DeepClean(context.Background()); err != nil { - panic(err) - } - }() - } - - mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler)) - mux.Handle("/api/v1/get-dump-requests", withLogging(apiGetPendingDumpRequestsHandler)) - mux.Handle("/api/v1/submit-dump", withLogging(apiSubmitDumpHandler)) - mux.Handle("/api/v1/query", withLogging(apiQueryHandler)) - mux.Handle("/api/v1/bootstrap", withLogging(apiBootstrapHandler)) - mux.Handle("/api/v1/register", withLogging(apiRegisterHandler)) - mux.Handle("/api/v1/banner", withLogging(apiBannerHandler)) - mux.Handle("/api/v1/download", withLogging(apiDownloadHandler)) - mux.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler)) - mux.Handle("/api/v1/get-deletion-requests", withLogging(getDeletionRequestsHandler)) - mux.Handle("/api/v1/add-deletion-request", withLogging(addDeletionRequestHandler)) - mux.Handle("/api/v1/slsa-status", withLogging(slsaStatusHandler)) - mux.Handle("/api/v1/feedback", withLogging(feedbackHandler)) - mux.Handle("/healthcheck", withLogging(healthCheckHandler)) - mux.Handle("/internal/api/v1/usage-stats", withLogging(usageStatsHandler)) - mux.Handle("/internal/api/v1/stats", withLogging(statsHandler)) - if isTestEnvironment() { - mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler)) - mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler)) - } - - fmt.Println("Listening on localhost:8080") - log.Fatal(http.ListenAndServe(":8080", mux)) -} - func checkGormResult(result *gorm.DB) { checkGormError(result.Error, 1) } diff --git a/backend/server/server_test.go b/backend/server/server_test.go index a027216..9187409 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -24,6 +24,7 @@ import ( func TestESubmitThenQuery(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) // Register a few devices userId := data.UserId("key") @@ -32,11 +33,11 @@ func TestESubmitThenQuery(t *testing.T) { otherUser := data.UserId("otherkey") otherDev := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Submit a few entries for different devices entry := testutils.MakeFakeHistoryEntry("ls ~/") @@ -45,12 +46,12 @@ func TestESubmitThenQuery(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -79,7 +80,7 @@ func TestESubmitThenQuery(t *testing.T) { // Same for device id 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -107,7 +108,7 @@ func TestESubmitThenQuery(t *testing.T) { // Bootstrap handler should return 2 entries, one for each device w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil) - apiBootstrapHandler(w, searchReq) + s.apiBootstrapHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -124,6 +125,7 @@ func TestESubmitThenQuery(t *testing.T) { func TestDumpRequestAndResponse(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) // Register a first device for two different users userId := data.UserId("dkey") @@ -133,17 +135,17 @@ func TestDumpRequestAndResponse(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Query for dump requests, there should be one for userId w := httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -163,7 +165,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And one for otherUser w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -183,7 +185,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none if we query for a user ID that doesn't exit w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -193,7 +195,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And none for a missing user ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -211,11 +213,11 @@ func TestDumpRequestAndResponse(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody)) - apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitDumpHandler(httptest.NewRecorder(), submitReq) // Check that the dump request is no longer there for userId for either device ID w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -226,7 +228,7 @@ func TestDumpRequestAndResponse(t *testing.T) { w = httptest.NewRecorder() // The other user - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -236,7 +238,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // But it is there for the other user w = httptest.NewRecorder() - apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) + s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil)) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -257,7 +259,7 @@ func TestDumpRequestAndResponse(t *testing.T) { // And finally, query to ensure that the dumped entries are in the DB w = httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -323,6 +325,7 @@ func TestUpdateReleaseVersion(t *testing.T) { func TestDeletionRequests(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) // Register two devices for two different users userId := data.UserId("dkey") @@ -332,13 +335,13 @@ func TestDeletionRequests(t *testing.T) { otherDev1 := uuid.Must(uuid.NewRandom()).String() otherDev2 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // Add an entry for user1 entry1 := testutils.MakeFakeHistoryEntry("ls ~/") @@ -348,7 +351,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // And another entry for user1 entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -358,7 +361,7 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // And an entry for user2 that has the same timestamp as the previous entry entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar") @@ -369,12 +372,12 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Query for device id 1 w := httptest.NewRecorder() searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res := w.Result() defer res.Body.Close() respBody, err := io.ReadAll(res.Body) @@ -413,13 +416,13 @@ func TestDeletionRequests(t *testing.T) { reqBody, err = json.Marshal(delReq) testutils.Check(t, err) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - addDeletionRequestHandler(httptest.NewRecorder(), req) + s.addDeletionRequestHandler(httptest.NewRecorder(), req) // Query again for device id 1 and get a single result time.Sleep(10 * time.Millisecond) w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -447,7 +450,7 @@ func TestDeletionRequests(t *testing.T) { // Query for user 2 w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil) - apiQueryHandler(w, searchReq) + s.apiQueryHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -475,7 +478,7 @@ func TestDeletionRequests(t *testing.T) { // Query for deletion requests w = httptest.NewRecorder() searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - getDeletionRequestsHandler(w, searchReq) + s.getDeletionRequestsHandler(w, searchReq) res = w.Result() defer res.Body.Close() respBody, err = io.ReadAll(res.Body) @@ -504,8 +507,9 @@ func TestDeletionRequests(t *testing.T) { } func TestHealthcheck(t *testing.T) { + s := NewServer(GLOBAL_DB) w := httptest.NewRecorder() - healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) + s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { t.Fatalf("expected 200 resp code for healthCheckHandler") } @@ -524,6 +528,7 @@ func TestHealthcheck(t *testing.T) { func TestLimitRegistrations(t *testing.T) { // Set up InitDB() + s := NewServer(GLOBAL_DB) checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices")) defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() @@ -531,28 +536,29 @@ func TestLimitRegistrations(t *testing.T) { // Register three devices across two users deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) // And this next one should fail since it is a new user defer func() { _ = recover() }() deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) t.Errorf("expected panic") } func TestCleanDatabaseNoErrors(t *testing.T) { // Init InitDB() + s := NewServer(GLOBAL_DB) // Create a user and an entry userId := data.UserId("dkey") devId1 := uuid.Must(uuid.NewRandom()).String() deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil) - apiRegisterHandler(httptest.NewRecorder(), deviceReq) + s.apiRegisterHandler(httptest.NewRecorder(), deviceReq) entry1 := testutils.MakeFakeHistoryEntry("ls ~/") entry1.DeviceId = devId1 encEntry, err := data.EncryptHistoryEntry("dkey", entry1) @@ -560,7 +566,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry}) testutils.Check(t, err) submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody)) - apiSubmitHandler(httptest.NewRecorder(), submitReq) + s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics testutils.Check(t, GLOBAL_DB.Clean(context.TODO())) diff --git a/backend/server/srv.go b/backend/server/srv.go new file mode 100644 index 0000000..c373167 --- /dev/null +++ b/backend/server/srv.go @@ -0,0 +1,611 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "math" + "net/http" + pprofhttp "net/http/pprof" + "os" + "strconv" + "strings" + "time" + + "github.com/DataDog/datadog-go/statsd" + "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/shared" + "github.com/rodaine/table" + httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/profiler" +) + +type Srv struct { + db *database.DB + statsd *statsd.Client +} + +type ServerOption func(*Srv) + +func WithStatsd(statsd *statsd.Client) ServerOption { + return func(s *Srv) { + s.statsd = statsd + } +} + +func NewServer(db *database.DB, options ...ServerOption) *Srv { + srv := Srv{db: db} + for _, option := range options { + option(&srv) + } + return &srv +} + +func (s *Srv) Run(ctx context.Context, addr string) error { + mux := httptrace.NewServeMux() + + if isProductionEnvironment() { + defer configureObservability(mux)() + go func() { + if err := s.db.DeepClean(ctx); err != nil { + panic(err) + } + }() + } + loggerMiddleware := withLogging(s.statsd) + + mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) + mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) + mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) + mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) + mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) + mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) + mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) + mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) + mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) + mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) + mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) + mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) + mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) + mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) + mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) + mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) + if isTestEnvironment() { + mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) + mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) + } + + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + } + + fmt.Printf("Listening on %s\n", addr) + if err := httpServer.ListenAndServe(); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("http.ListenAndServe: %w", err) + } + } + + return nil +} + +func (s *Srv) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) + if len(entries) == 0 { + return + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId) + checkGormError(err, 0) + + if len(devices) == 0 { + panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) + } + fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) + + err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) + if err != nil { + panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) + } + if s.statsd != nil { + s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) + checkGormError(err, 1) + fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } +} + +func (s *Srv) apiQueryHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + // Delete any entries that match a pending deletion request + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + for _, request := range deletionRequests { + _, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request) + checkGormError(err, 0) + } + + // Then retrieve + historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) + checkGormError(err, 0) + fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } + + // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids + // blocking the entire response. This does have a potential race condition, but that is fine. + if isProductionEnvironment() { + go func() { + span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + span.Finish(tracer.WithError(err)) + }() + } else { + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + if err != nil { + panic("failed to increment read counts") + } + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.query", []string{}, 1.0) + } +} + +func (s *Srv) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + srcDeviceId := getRequiredQueryParam(r, "source_device_id") + requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) + + // sanity check + for _, entry := range entries { + entry.DeviceId = requestingDeviceId + if entry.UserId != userId { + panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) + } + } + + err = s.db.EncHistoryCreateMulti(r.Context(), entries...) + checkGormError(err, 0) + err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) + checkGormError(err, 0) + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) apiBannerHandler(w http.ResponseWriter, r *http.Request) { + commitHash := getRequiredQueryParam(r, "commit_hash") + deviceId := getRequiredQueryParam(r, "device_id") + forcedBanner := r.URL.Query().Get("forced_banner") + fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) + if getHishtoryVersion(r) == "v0.160" { + w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) + return + } + w.Write([]byte(html.EscapeString(forcedBanner))) +} + +func (s *Srv) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + var dumpRequests []*shared.DumpRequest + // Filter out ones requested by the hishtory instance that sent this request + dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + + if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Srv) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // Increment the ReadCount + err := s.db.DeletionRequestInc(r.Context(), userId, deviceId) + checkGormError(err, 0) + + // Return all the deletion requests + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Srv) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var request shared.DeletionRequest + + if err := json.Unmarshal(data, &request); err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + request.ReadCount = 0 + fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) + + err = s.db.DeletionRequestCreate(r.Context(), &request) + checkGormError(err, 0) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (_ *Srv) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { + updateInfo := buildUpdateInfo(ReleaseVersion) + resp, err := json.Marshal(updateInfo) + if err != nil { + panic(err) + } + w.Write(resp) +} + +func (s *Srv) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { + if getMaximumNumberOfAllowedUsers() < math.MaxInt { + numDistinctUsers, err := s.db.DistinctUsers(r.Context()) + if err != nil { + panic(fmt.Errorf("db.DistinctUsers: %w", err)) + } + if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { + panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) + } + } + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) + checkGormError(err, 0) + fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) + if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { + checkGormError(err, 0) + } + + if existingDevicesCount > 0 { + err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) + checkGormError(err, 0) + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.register", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) triggerCronHandler(w http.ResponseWriter, r *http.Request) { + err := cron(r.Context()) + if err != nil { + panic(err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { + // returns "OK" unless there is a current SLSA bug + v := getHishtoryVersion(r) + if !strings.Contains(v, "v0.") { + w.Write([]byte("OK")) + return + } + vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) + if err != nil { + w.Write([]byte("OK")) + return + } + if vNum < 159 { + w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) + return + } + w.Write([]byte("OK")) +} + +func (s *Srv) feedbackHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var feedback shared.Feedback + err = json.Unmarshal(data, &feedback) + if err != nil { + panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) + } + fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) + err = s.db.FeedbackCreate(r.Context(), &feedback) + checkGormError(err, 0) + + if s.statsd != nil { + s.statsd.Incr("hishtory.uninstall", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) healthCheckHandler(w http.ResponseWriter, r *http.Request) { + if isProductionEnvironment() { + // Check that we have a reasonable looking set of devices/entries in the DB + //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() + //if err != nil { + // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) + //} + //defer rows.Close() + //if !rows.Next() { + // panic("Suspiciously few enc history entries!") + //} + encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err, 0) + if encHistoryEntryCount < 1000 { + panic("Suspiciously few enc history entries!") + } + + deviceCount, err := s.db.DevicesCount(r.Context()) + checkGormError(err, 0) + if deviceCount < 100 { + panic("Suspiciously few devices!") + } + // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. + err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ + EncryptedData: []byte("data"), + Nonce: []byte("nonce"), + DeviceId: "healthcheck_device_id", + UserId: "healthcheck_user_id", + Date: time.Now(), + EncryptedId: "healthcheck_enc_id", + ReadCount: 10000, + }) + checkGormError(err, 0) + } else { + err := s.db.Ping() + if err != nil { + panic(fmt.Errorf("failed to ping DB: %w", err)) + } + } + w.Write([]byte("OK")) +} + +func (s *Srv) usageStatsHandler(w http.ResponseWriter, r *http.Request) { + usageData, err := s.db.UsageDataStats(r.Context()) + if err != nil { + panic(fmt.Errorf("db.UsageDataStats: %w", err)) + } + + tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") + tbl.WithWriter(w) + for _, data := range usageData { + versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") + lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") + tbl.AddRow( + data.RegistrationDate.Format(shared.DateOnly), + data.NumDevices, + data.NumEntries, + data.NumQueries, + data.LastUsedDate.Format(shared.DateOnly), + lastQueryStr, + versions, + data.IpAddresses, + ) + } + tbl.Print() +} + +func (s *Srv) statsHandler(w http.ResponseWriter, r *http.Request) { + numDevices, err := s.db.DevicesCount(r.Context()) + checkGormError(err, 0) + + numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) + checkGormError(err, 0) + + numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err, 0) + + oneWeek := time.Hour * 24 * 7 + weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) + checkGormError(err, 0) + + weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) + checkGormError(err, 0) + + lastRegistration, err := s.db.LastRegistration(r.Context()) + checkGormError(err, 0) + + _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) + _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) + _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) + _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) + _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) + _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) +} + +func (s *Srv) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { + if r.Host == "api.hishtory.dev" || isProductionEnvironment() { + panic("refusing to wipe the DB for prod") + } + if !isTestEnvironment() { + panic("refusing to wipe the DB non-test environment") + } + + err := s.db.EncHistoryClear(r.Context()) + checkGormError(err, 0) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Srv) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { + stats, err := s.db.Stats() + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) +} + +func (s *Srv) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error { + var usageData []shared.UsageData + usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId) + if err != nil { + return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) + } + if len(usageData) == 0 { + err := s.db.UsageDataCreate( + ctx, + &shared.UsageData{ + UserId: userId, + DeviceId: deviceId, + LastUsed: time.Now(), + NumEntriesHandled: numEntriesHandled, + Version: version, + }, + ) + if err != nil { + return fmt.Errorf("db.UsageDataCreate: %w", err) + } + } else { + usage := usageData[0] + + if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { + return fmt.Errorf("db.UsageDataUpdate: %w", err) + } + if numEntriesHandled > 0 { + if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) + } + } + if usage.Version != version { + if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { + return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) + } + } + } + if isQuery { + if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) + } + } + + return nil +} + +func configureObservability(mux *httptrace.ServeMux) func() { + // Profiler + err := profiler.Start( + profiler.WithService("hishtory-api"), + profiler.WithVersion(ReleaseVersion), + profiler.WithAPIKey(os.Getenv("DD_API_KEY")), + profiler.WithUDS("/var/run/datadog/apm.socket"), + profiler.WithProfileTypes( + profiler.CPUProfile, + profiler.HeapProfile, + ), + ) + if err != nil { + fmt.Printf("Failed to start DataDog profiler: %v\n", err) + } + // Tracer + tracer.Start( + tracer.WithRuntimeMetrics(), + tracer.WithService("hishtory-api"), + tracer.WithUDS("/var/run/datadog/apm.socket"), + ) + // TODO: should this be here? + defer tracer.Stop() + + // Pprof + mux.HandleFunc("/debug/pprof/", pprofhttp.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) + + // Func to stop all of the above + return func() { + profiler.Stop() + tracer.Stop() + } +} From 60a0e20dd9b026c3f12b2021d6335acb4862bcce Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 09:48:48 -0400 Subject: [PATCH 2/6] extract server object to its own package --- backend/server/server.go | 69 +-- backend/server/server_test.go | 13 +- backend/server/srv.go | 611 --------------------- internal/server/api.go | 238 ++++++++ {backend => internal}/server/middleware.go | 9 +- internal/server/srv.go | 378 +++++++++++++ internal/server/util.go | 93 ++++ 7 files changed, 744 insertions(+), 667 deletions(-) delete mode 100644 backend/server/srv.go create mode 100644 internal/server/api.go rename {backend => internal}/server/middleware.go (99%) create mode 100644 internal/server/srv.go create mode 100644 internal/server/util.go diff --git a/backend/server/server.go b/backend/server/server.go index 4cc8377..bc65d50 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "fmt" + "github.com/ddworken/hishtory/internal/server" "io" "log" - "math" "net/http" "os" "runtime" @@ -33,26 +33,6 @@ var ( ReleaseVersion string = "UNKNOWN" ) -func getRequiredQueryParam(r *http.Request, queryParam string) string { - val := r.URL.Query().Get(queryParam) - if val == "" { - panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam)) - } - return val -} - -func getHishtoryVersion(r *http.Request) string { - return r.Header.Get("X-Hishtory-Version") -} - -func getRemoteAddr(r *http.Request) string { - addr, ok := r.Header["X-Real-Ip"] - if !ok || len(addr) == 0 { - return "UnknownIp" - } - return addr[0] -} - func isTestEnvironment() bool { return os.Getenv("HISHTORY_TEST") != "" } @@ -129,19 +109,17 @@ func init() { go runBackgroundJobs(context.Background()) } -func cron(ctx context.Context) error { - err := updateReleaseVersion() - if err != nil { - panic(err) +func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { + if err := updateReleaseVersion(); err != nil { + return fmt.Errorf("updateReleaseVersion: %w", err) } - err = GLOBAL_DB.Clean(ctx) - if err != nil { - panic(err) + + if err := db.Clean(ctx); err != nil { + return fmt.Errorf("db.Clean: %w", err) } - if GLOBAL_STATSD != nil { - err = GLOBAL_STATSD.Flush() - if err != nil { - panic(err) + if stats != nil { + if err := stats.Flush(); err != nil { + return fmt.Errorf("stats.Flush: %w", err) } } return nil @@ -150,9 +128,12 @@ func cron(ctx context.Context) error { func runBackgroundJobs(ctx context.Context) { time.Sleep(5 * time.Second) for { - err := cron(ctx) + err := cron(ctx, GLOBAL_DB, GLOBAL_STATSD) if err != nil { fmt.Printf("Cron failure: %v", err) + + // cron no longer panics, panicking here. + panic(err) } time.Sleep(10 * time.Minute) } @@ -291,7 +272,15 @@ func main() { // TODO: remove this global once we have a better way to pass it around GLOBAL_STATSD = s - srv := NewServer(GLOBAL_DB, WithStatsd(s)) + srv := server.NewServer( + GLOBAL_DB, + server.WithStatsd(s), + server.WithReleaseVersion(ReleaseVersion), + server.IsTestEnvironment(isTestEnvironment()), + server.IsProductionEnvironment(isProductionEnvironment()), + server.WithCron(cron), + server.WithUpdateInfo(buildUpdateInfo(ReleaseVersion)), + ) if err := srv.Run(context.Background(), ":8080"); err != nil { panic(err) @@ -311,17 +300,5 @@ func checkGormError(err error, skip int) { panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err)) } -func getMaximumNumberOfAllowedUsers() int { - maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS") - if maxNumUsersStr == "" { - return math.MaxInt - } - maxNumUsers, err := strconv.Atoi(maxNumUsersStr) - if err != nil { - return math.MaxInt - } - return maxNumUsers -} - // TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required? // TODO: Add error checking for the calls to updateUsageData(...) that logs it/triggers an alert in prod, but is an error in test diff --git a/backend/server/server_test.go b/backend/server/server_test.go index 9187409..3a8b26b 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/internal/server" "github.com/stretchr/testify/require" "io" "net/http" @@ -24,7 +25,7 @@ import ( func TestESubmitThenQuery(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Register a few devices userId := data.UserId("key") @@ -125,7 +126,7 @@ func TestESubmitThenQuery(t *testing.T) { func TestDumpRequestAndResponse(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Register a first device for two different users userId := data.UserId("dkey") @@ -325,7 +326,7 @@ func TestUpdateReleaseVersion(t *testing.T) { func TestDeletionRequests(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Register two devices for two different users userId := data.UserId("dkey") @@ -507,7 +508,7 @@ func TestDeletionRequests(t *testing.T) { } func TestHealthcheck(t *testing.T) { - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) w := httptest.NewRecorder() s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { @@ -528,7 +529,7 @@ func TestHealthcheck(t *testing.T) { func TestLimitRegistrations(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices")) defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() @@ -552,7 +553,7 @@ func TestLimitRegistrations(t *testing.T) { func TestCleanDatabaseNoErrors(t *testing.T) { // Init InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Create a user and an entry userId := data.UserId("dkey") diff --git a/backend/server/srv.go b/backend/server/srv.go deleted file mode 100644 index c373167..0000000 --- a/backend/server/srv.go +++ /dev/null @@ -1,611 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "html" - "io" - "math" - "net/http" - pprofhttp "net/http/pprof" - "os" - "strconv" - "strings" - "time" - - "github.com/DataDog/datadog-go/statsd" - "github.com/ddworken/hishtory/internal/database" - "github.com/ddworken/hishtory/shared" - "github.com/rodaine/table" - httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/DataDog/dd-trace-go.v1/profiler" -) - -type Srv struct { - db *database.DB - statsd *statsd.Client -} - -type ServerOption func(*Srv) - -func WithStatsd(statsd *statsd.Client) ServerOption { - return func(s *Srv) { - s.statsd = statsd - } -} - -func NewServer(db *database.DB, options ...ServerOption) *Srv { - srv := Srv{db: db} - for _, option := range options { - option(&srv) - } - return &srv -} - -func (s *Srv) Run(ctx context.Context, addr string) error { - mux := httptrace.NewServeMux() - - if isProductionEnvironment() { - defer configureObservability(mux)() - go func() { - if err := s.db.DeepClean(ctx); err != nil { - panic(err) - } - }() - } - loggerMiddleware := withLogging(s.statsd) - - mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) - mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) - mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) - mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) - mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) - mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) - mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) - mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) - mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) - mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) - mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) - mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) - mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) - mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) - mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) - mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) - if isTestEnvironment() { - mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) - mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) - } - - httpServer := &http.Server{ - Addr: addr, - Handler: mux, - } - - fmt.Printf("Listening on %s\n", addr) - if err := httpServer.ListenAndServe(); err != nil { - if !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("http.ListenAndServe: %w", err) - } - } - - return nil -} - -func (s *Srv) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) - if len(entries) == 0 { - return - } - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId) - checkGormError(err, 0) - - if len(devices) == 0 { - panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) - } - fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) - - err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) - if err != nil { - panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) - } - if s.statsd != nil { - s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) - checkGormError(err, 1) - fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } -} - -func (s *Srv) apiQueryHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - // Delete any entries that match a pending deletion request - deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - for _, request := range deletionRequests { - _, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request) - checkGormError(err, 0) - } - - // Then retrieve - historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) - checkGormError(err, 0) - fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } - - // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids - // blocking the entire response. This does have a potential race condition, but that is fine. - if isProductionEnvironment() { - go func() { - span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") - err := s.db.DeviceIncrementReadCounts(ctx, deviceId) - span.Finish(tracer.WithError(err)) - }() - } else { - err := s.db.DeviceIncrementReadCounts(ctx, deviceId) - if err != nil { - panic("failed to increment read counts") - } - } - - if s.statsd != nil { - s.statsd.Incr("hishtory.query", []string{}, 1.0) - } -} - -func (s *Srv) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - srcDeviceId := getRequiredQueryParam(r, "source_device_id") - requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) - - // sanity check - for _, entry := range entries { - entry.DeviceId = requestingDeviceId - if entry.UserId != userId { - panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) - } - } - - err = s.db.EncHistoryCreateMulti(r.Context(), entries...) - checkGormError(err, 0) - err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) - checkGormError(err, 0) - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) apiBannerHandler(w http.ResponseWriter, r *http.Request) { - commitHash := getRequiredQueryParam(r, "commit_hash") - deviceId := getRequiredQueryParam(r, "device_id") - forcedBanner := r.URL.Query().Get("forced_banner") - fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) - if getHishtoryVersion(r) == "v0.160" { - w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) - return - } - w.Write([]byte(html.EscapeString(forcedBanner))) -} - -func (s *Srv) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - var dumpRequests []*shared.DumpRequest - // Filter out ones requested by the hishtory instance that sent this request - dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - - if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func (s *Srv) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // Increment the ReadCount - err := s.db.DeletionRequestInc(r.Context(), userId, deviceId) - checkGormError(err, 0) - - // Return all the deletion requests - deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func (s *Srv) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var request shared.DeletionRequest - - if err := json.Unmarshal(data, &request); err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - request.ReadCount = 0 - fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) - - err = s.db.DeletionRequestCreate(r.Context(), &request) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (_ *Srv) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { - updateInfo := buildUpdateInfo(ReleaseVersion) - resp, err := json.Marshal(updateInfo) - if err != nil { - panic(err) - } - w.Write(resp) -} - -func (s *Srv) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { - if getMaximumNumberOfAllowedUsers() < math.MaxInt { - numDistinctUsers, err := s.db.DistinctUsers(r.Context()) - if err != nil { - panic(fmt.Errorf("db.DistinctUsers: %w", err)) - } - if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { - panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) - } - } - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) - checkGormError(err, 0) - fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) - if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { - checkGormError(err, 0) - } - - if existingDevicesCount > 0 { - err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) - checkGormError(err, 0) - } - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - if s.statsd != nil { - s.statsd.Incr("hishtory.register", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) triggerCronHandler(w http.ResponseWriter, r *http.Request) { - err := cron(r.Context()) - if err != nil { - panic(err) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { - // returns "OK" unless there is a current SLSA bug - v := getHishtoryVersion(r) - if !strings.Contains(v, "v0.") { - w.Write([]byte("OK")) - return - } - vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) - if err != nil { - w.Write([]byte("OK")) - return - } - if vNum < 159 { - w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) - return - } - w.Write([]byte("OK")) -} - -func (s *Srv) feedbackHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var feedback shared.Feedback - err = json.Unmarshal(data, &feedback) - if err != nil { - panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) - } - fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) - err = s.db.FeedbackCreate(r.Context(), &feedback) - checkGormError(err, 0) - - if s.statsd != nil { - s.statsd.Incr("hishtory.uninstall", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) healthCheckHandler(w http.ResponseWriter, r *http.Request) { - if isProductionEnvironment() { - // Check that we have a reasonable looking set of devices/entries in the DB - //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() - //if err != nil { - // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) - //} - //defer rows.Close() - //if !rows.Next() { - // panic("Suspiciously few enc history entries!") - //} - encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) - checkGormError(err, 0) - if encHistoryEntryCount < 1000 { - panic("Suspiciously few enc history entries!") - } - - deviceCount, err := s.db.DevicesCount(r.Context()) - checkGormError(err, 0) - if deviceCount < 100 { - panic("Suspiciously few devices!") - } - // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. - err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ - EncryptedData: []byte("data"), - Nonce: []byte("nonce"), - DeviceId: "healthcheck_device_id", - UserId: "healthcheck_user_id", - Date: time.Now(), - EncryptedId: "healthcheck_enc_id", - ReadCount: 10000, - }) - checkGormError(err, 0) - } else { - err := s.db.Ping() - if err != nil { - panic(fmt.Errorf("failed to ping DB: %w", err)) - } - } - w.Write([]byte("OK")) -} - -func (s *Srv) usageStatsHandler(w http.ResponseWriter, r *http.Request) { - usageData, err := s.db.UsageDataStats(r.Context()) - if err != nil { - panic(fmt.Errorf("db.UsageDataStats: %w", err)) - } - - tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") - tbl.WithWriter(w) - for _, data := range usageData { - versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") - lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") - tbl.AddRow( - data.RegistrationDate.Format(shared.DateOnly), - data.NumDevices, - data.NumEntries, - data.NumQueries, - data.LastUsedDate.Format(shared.DateOnly), - lastQueryStr, - versions, - data.IpAddresses, - ) - } - tbl.Print() -} - -func (s *Srv) statsHandler(w http.ResponseWriter, r *http.Request) { - numDevices, err := s.db.DevicesCount(r.Context()) - checkGormError(err, 0) - - numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) - checkGormError(err, 0) - - numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) - checkGormError(err, 0) - - oneWeek := time.Hour * 24 * 7 - weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) - checkGormError(err, 0) - - weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) - checkGormError(err, 0) - - lastRegistration, err := s.db.LastRegistration(r.Context()) - checkGormError(err, 0) - - _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) - _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) - _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) - _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) - _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) - _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) -} - -func (s *Srv) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { - if r.Host == "api.hishtory.dev" || isProductionEnvironment() { - panic("refusing to wipe the DB for prod") - } - if !isTestEnvironment() { - panic("refusing to wipe the DB non-test environment") - } - - err := s.db.EncHistoryClear(r.Context()) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { - stats, err := s.db.Stats() - if err != nil { - panic(err) - } - - _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) -} - -func (s *Srv) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error { - var usageData []shared.UsageData - usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId) - if err != nil { - return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) - } - if len(usageData) == 0 { - err := s.db.UsageDataCreate( - ctx, - &shared.UsageData{ - UserId: userId, - DeviceId: deviceId, - LastUsed: time.Now(), - NumEntriesHandled: numEntriesHandled, - Version: version, - }, - ) - if err != nil { - return fmt.Errorf("db.UsageDataCreate: %w", err) - } - } else { - usage := usageData[0] - - if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { - return fmt.Errorf("db.UsageDataUpdate: %w", err) - } - if numEntriesHandled > 0 { - if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { - return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) - } - } - if usage.Version != version { - if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { - return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) - } - } - } - if isQuery { - if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { - return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) - } - } - - return nil -} - -func configureObservability(mux *httptrace.ServeMux) func() { - // Profiler - err := profiler.Start( - profiler.WithService("hishtory-api"), - profiler.WithVersion(ReleaseVersion), - profiler.WithAPIKey(os.Getenv("DD_API_KEY")), - profiler.WithUDS("/var/run/datadog/apm.socket"), - profiler.WithProfileTypes( - profiler.CPUProfile, - profiler.HeapProfile, - ), - ) - if err != nil { - fmt.Printf("Failed to start DataDog profiler: %v\n", err) - } - // Tracer - tracer.Start( - tracer.WithRuntimeMetrics(), - tracer.WithService("hishtory-api"), - tracer.WithUDS("/var/run/datadog/apm.socket"), - ) - // TODO: should this be here? - defer tracer.Stop() - - // Pprof - mux.HandleFunc("/debug/pprof/", pprofhttp.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) - - // Func to stop all of the above - return func() { - profiler.Stop() - tracer.Stop() - } -} diff --git a/internal/server/api.go b/internal/server/api.go new file mode 100644 index 0000000..bf27b0a --- /dev/null +++ b/internal/server/api.go @@ -0,0 +1,238 @@ +package server + +import ( + "encoding/json" + "fmt" + "github.com/ddworken/hishtory/shared" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "html" + "io" + "math" + "net/http" + "time" +) + +func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) + if len(entries) == 0 { + return + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId) + checkGormError(err) + + if len(devices) == 0 { + panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) + } + fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) + + err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) + if err != nil { + panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) + } + if s.statsd != nil { + s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) + checkGormError(err) + fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } +} + +func (s *Server) apiQueryHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + // Delete any entries that match a pending deletion request + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err) + for _, request := range deletionRequests { + _, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request) + checkGormError(err) + } + + // Then retrieve + historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) + checkGormError(err) + fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } + + // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids + // blocking the entire response. This does have a potential race condition, but that is fine. + if s.isProductionEnvironment { + go func() { + span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + span.Finish(tracer.WithError(err)) + }() + } else { + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + if err != nil { + panic("failed to increment read counts") + } + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.query", []string{}, 1.0) + } +} + +func (s *Server) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + srcDeviceId := getRequiredQueryParam(r, "source_device_id") + requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) + + // sanity check + for _, entry := range entries { + entry.DeviceId = requestingDeviceId + if entry.UserId != userId { + panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) + } + } + + err = s.db.EncHistoryCreateMulti(r.Context(), entries...) + checkGormError(err) + err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) + checkGormError(err) + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) apiBannerHandler(w http.ResponseWriter, r *http.Request) { + commitHash := getRequiredQueryParam(r, "commit_hash") + deviceId := getRequiredQueryParam(r, "device_id") + forcedBanner := r.URL.Query().Get("forced_banner") + fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) + if getHishtoryVersion(r) == "v0.160" { + w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) + return + } + w.Write([]byte(html.EscapeString(forcedBanner))) +} + +func (s *Server) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + var dumpRequests []*shared.DumpRequest + // Filter out ones requested by the hishtory instance that sent this request + dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err) + + if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Server) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { + err := json.NewEncoder(w).Encode(s.updateInfo) + + if err != nil { + panic(fmt.Errorf("failed to JSON marshall the update info: %w", err)) + } +} + +func (s *Server) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { + if getMaximumNumberOfAllowedUsers() < math.MaxInt { + numDistinctUsers, err := s.db.DistinctUsers(r.Context()) + if err != nil { + panic(fmt.Errorf("db.DistinctUsers: %w", err)) + } + if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { + panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) + } + } + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) + checkGormError(err) + fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) + if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { + checkGormError(err) + } + + if existingDevicesCount > 0 { + err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) + checkGormError(err) + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.register", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} diff --git a/backend/server/middleware.go b/internal/server/middleware.go similarity index 99% rename from backend/server/middleware.go rename to internal/server/middleware.go index 5d8d7e7..986b1dc 100644 --- a/backend/server/middleware.go +++ b/internal/server/middleware.go @@ -1,15 +1,16 @@ -package main +package server import ( "fmt" - "github.com/DataDog/datadog-go/statsd" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "net/http" "reflect" "runtime" "strings" "time" + + "github.com/DataDog/datadog-go/statsd" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) type loggedResponseData struct { diff --git a/internal/server/srv.go b/internal/server/srv.go new file mode 100644 index 0000000..b30ed65 --- /dev/null +++ b/internal/server/srv.go @@ -0,0 +1,378 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/DataDog/datadog-go/statsd" + "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/shared" + "github.com/rodaine/table" + httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" +) + +type Server struct { + db *database.DB + statsd *statsd.Client + + isProductionEnvironment bool + isTestEnvironment bool + releaseVersion string + cronFn CronFn + updateInfo shared.UpdateInfo +} + +type CronFn func(ctx context.Context, db *database.DB, stats *statsd.Client) error +type Option func(*Server) + +func WithStatsd(statsd *statsd.Client) Option { + return func(s *Server) { + s.statsd = statsd + } +} + +func WithReleaseVersion(releaseVersion string) Option { + return func(s *Server) { + s.releaseVersion = releaseVersion + } +} + +func WithCron(cronFn CronFn) Option { + return func(s *Server) { + s.cronFn = cronFn + } +} + +func WithUpdateInfo(updateInfo shared.UpdateInfo) Option { + return func(s *Server) { + s.updateInfo = updateInfo + } +} + +func IsProductionEnvironment(v bool) Option { + return func(s *Server) { + s.isProductionEnvironment = v + } +} + +func IsTestEnvironment(v bool) Option { + return func(s *Server) { + s.isTestEnvironment = v + } +} + +func NewServer(db *database.DB, options ...Option) *Server { + srv := Server{db: db} + for _, option := range options { + option(&srv) + } + return &srv +} + +func (s *Server) Run(ctx context.Context, addr string) error { + mux := httptrace.NewServeMux() + + if s.isProductionEnvironment { + defer configureObservability(mux, s.releaseVersion)() + go func() { + if err := s.db.DeepClean(ctx); err != nil { + panic(err) + } + }() + } + loggerMiddleware := withLogging(s.statsd) + + mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) + mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) + mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) + mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) + mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) + mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) + mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) + mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) + mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) + mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) + mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) + mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) + mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) + mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) + mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) + mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) + if s.isTestEnvironment { + mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) + mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) + } + + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + } + + fmt.Printf("Listening on %s\n", addr) + if err := httpServer.ListenAndServe(); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("http.ListenAndServe: %w", err) + } + } + + return nil +} + +func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // Increment the ReadCount + err := s.db.DeletionRequestInc(r.Context(), userId, deviceId) + checkGormError(err) + + // Return all the deletion requests + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err) + if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var request shared.DeletionRequest + + if err := json.Unmarshal(data, &request); err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + request.ReadCount = 0 + fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) + + err = s.db.DeletionRequestCreate(r.Context(), &request) + checkGormError(err) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) triggerCronHandler(w http.ResponseWriter, r *http.Request) { + err := s.cronFn(r.Context(), s.db, s.statsd) + if err != nil { + panic(err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { + // returns "OK" unless there is a current SLSA bug + v := getHishtoryVersion(r) + if !strings.Contains(v, "v0.") { + w.Write([]byte("OK")) + return + } + vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) + if err != nil { + w.Write([]byte("OK")) + return + } + if vNum < 159 { + w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) + return + } + w.Write([]byte("OK")) +} + +func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var feedback shared.Feedback + err = json.Unmarshal(data, &feedback) + if err != nil { + panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) + } + fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) + err = s.db.FeedbackCreate(r.Context(), &feedback) + checkGormError(err) + + if s.statsd != nil { + s.statsd.Incr("hishtory.uninstall", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) healthCheckHandler(w http.ResponseWriter, r *http.Request) { + if s.isProductionEnvironment { + // Check that we have a reasonable looking set of devices/entries in the DB + //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() + //if err != nil { + // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) + //} + //defer rows.Close() + //if !rows.Next() { + // panic("Suspiciously few enc history entries!") + //} + encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err) + if encHistoryEntryCount < 1000 { + panic("Suspiciously few enc history entries!") + } + + deviceCount, err := s.db.DevicesCount(r.Context()) + checkGormError(err) + if deviceCount < 100 { + panic("Suspiciously few devices!") + } + // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. + err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ + EncryptedData: []byte("data"), + Nonce: []byte("nonce"), + DeviceId: "healthcheck_device_id", + UserId: "healthcheck_user_id", + Date: time.Now(), + EncryptedId: "healthcheck_enc_id", + ReadCount: 10000, + }) + checkGormError(err) + } else { + err := s.db.Ping() + if err != nil { + panic(fmt.Errorf("failed to ping DB: %w", err)) + } + } + w.Write([]byte("OK")) +} + +func (s *Server) usageStatsHandler(w http.ResponseWriter, r *http.Request) { + usageData, err := s.db.UsageDataStats(r.Context()) + if err != nil { + panic(fmt.Errorf("db.UsageDataStats: %w", err)) + } + + tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") + tbl.WithWriter(w) + for _, data := range usageData { + versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") + lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") + tbl.AddRow( + data.RegistrationDate.Format(shared.DateOnly), + data.NumDevices, + data.NumEntries, + data.NumQueries, + data.LastUsedDate.Format(shared.DateOnly), + lastQueryStr, + versions, + data.IpAddresses, + ) + } + tbl.Print() +} + +func (s *Server) statsHandler(w http.ResponseWriter, r *http.Request) { + numDevices, err := s.db.DevicesCount(r.Context()) + checkGormError(err) + + numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) + checkGormError(err) + + numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err) + + oneWeek := time.Hour * 24 * 7 + weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) + checkGormError(err) + + weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) + checkGormError(err) + + lastRegistration, err := s.db.LastRegistration(r.Context()) + checkGormError(err) + + _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) + _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) + _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) + _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) + _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) + _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) +} + +func (s *Server) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { + if r.Host == "api.hishtory.dev" || s.isProductionEnvironment { + panic("refusing to wipe the DB for prod") + } + if !s.isTestEnvironment { + panic("refusing to wipe the DB non-test environment") + } + + err := s.db.EncHistoryClear(r.Context()) + checkGormError(err) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { + stats, err := s.db.Stats() + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) +} + +func (s *Server) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error { + var usageData []shared.UsageData + usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId) + if err != nil { + return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) + } + if len(usageData) == 0 { + err := s.db.UsageDataCreate( + ctx, + &shared.UsageData{ + UserId: userId, + DeviceId: deviceId, + LastUsed: time.Now(), + NumEntriesHandled: numEntriesHandled, + Version: version, + }, + ) + if err != nil { + return fmt.Errorf("db.UsageDataCreate: %w", err) + } + } else { + usage := usageData[0] + + if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { + return fmt.Errorf("db.UsageDataUpdate: %w", err) + } + if numEntriesHandled > 0 { + if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) + } + } + if usage.Version != version { + if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { + return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) + } + } + } + if isQuery { + if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) + } + } + + return nil +} diff --git a/internal/server/util.go b/internal/server/util.go new file mode 100644 index 0000000..7c84e8c --- /dev/null +++ b/internal/server/util.go @@ -0,0 +1,93 @@ +package server + +import ( + "fmt" + httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/profiler" + "math" + "net/http" + pprofhttp "net/http/pprof" + "os" + "runtime" + "strconv" +) + +func getMaximumNumberOfAllowedUsers() int { + maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS") + if maxNumUsersStr == "" { + return math.MaxInt + } + maxNumUsers, err := strconv.Atoi(maxNumUsersStr) + if err != nil { + return math.MaxInt + } + return maxNumUsers +} + +func configureObservability(mux *httptrace.ServeMux, releaseVersion string) func() { + // Profiler + err := profiler.Start( + profiler.WithService("hishtory-api"), + profiler.WithVersion(releaseVersion), + profiler.WithAPIKey(os.Getenv("DD_API_KEY")), + profiler.WithUDS("/var/run/datadog/apm.socket"), + profiler.WithProfileTypes( + profiler.CPUProfile, + profiler.HeapProfile, + ), + ) + if err != nil { + fmt.Printf("Failed to start DataDog profiler: %v\n", err) + } + // Tracer + tracer.Start( + tracer.WithRuntimeMetrics(), + tracer.WithService("hishtory-api"), + tracer.WithUDS("/var/run/datadog/apm.socket"), + ) + // TODO: should this be here? + defer tracer.Stop() + + // Pprof + mux.HandleFunc("/debug/pprof/", pprofhttp.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) + + // Func to stop all of the above + return func() { + profiler.Stop() + tracer.Stop() + } +} + +func getHishtoryVersion(r *http.Request) string { + return r.Header.Get("X-Hishtory-Version") +} + +func getRemoteAddr(r *http.Request) string { + addr, ok := r.Header["X-Real-Ip"] + if !ok || len(addr) == 0 { + return "UnknownIp" + } + return addr[0] +} + +func getRequiredQueryParam(r *http.Request, queryParam string) string { + val := r.URL.Query().Get(queryParam) + if val == "" { + panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam)) + } + return val +} + +func checkGormError(err error) { + if err == nil { + return + } + + _, filename, line, _ := runtime.Caller(1) + panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err)) +} From 0d30011a33622c6bac13e8bdd2cd0bb0ab4b7ca3 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 10:09:38 -0400 Subject: [PATCH 3/6] break down release versions and fix server tests --- backend/server/Dockerfile | 2 +- backend/server/server.go | 130 ++------------------ internal/release/release.go | 127 +++++++++++++++++++ internal/release/release_test.go | 33 +++++ {backend => internal}/server/server_test.go | 97 +++++++-------- internal/server/srv.go | 5 + shared/testutils/testutils.go | 2 +- 7 files changed, 225 insertions(+), 171 deletions(-) create mode 100644 internal/release/release.go create mode 100644 internal/release/release_test.go rename {backend => internal}/server/server_test.go (93%) diff --git a/backend/server/Dockerfile b/backend/server/Dockerfile index 55496c4..909cc61 100644 --- a/backend/server/Dockerfile +++ b/backend/server/Dockerfile @@ -6,7 +6,7 @@ RUN go mod download COPY . ./ ARG GOARCH RUN apk add --update --no-cache --virtual .build-deps build-base && \ - GOARCH=${GOARCH} go build -o /server -ldflags "-X main.ReleaseVersion=v0.`cat VERSION`" backend/server/server.go && \ + GOARCH=${GOARCH} go build -o /server -ldflags "-X release.Version=v0.`cat VERSION`" backend/server/server.go && \ apk del .build-deps FROM alpine:3.17 diff --git a/backend/server/server.go b/backend/server/server.go index bc65d50..9a4e847 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -2,21 +2,16 @@ package main import ( "context" - "encoding/json" "fmt" + "github.com/ddworken/hishtory/internal/release" "github.com/ddworken/hishtory/internal/server" - "io" "log" - "net/http" "os" "runtime" - "strconv" - "strings" "time" "github.com/DataDog/datadog-go/statsd" "github.com/ddworken/hishtory/internal/database" - "github.com/ddworken/hishtory/shared" _ "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -28,9 +23,8 @@ const ( ) var ( - GLOBAL_DB *database.DB - GLOBAL_STATSD *statsd.Client - ReleaseVersion string = "UNKNOWN" + GLOBAL_DB *database.DB + GLOBAL_STATSD *statsd.Client ) func isTestEnvironment() bool { @@ -102,15 +96,14 @@ func OpenDB() (*database.DB, error) { } func init() { - if ReleaseVersion == "UNKNOWN" && !isTestEnvironment() { + if release.Version == "UNKNOWN" && !isTestEnvironment() { panic("server.go was built without a ReleaseVersion!") } InitDB() - go runBackgroundJobs(context.Background()) } func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { - if err := updateReleaseVersion(); err != nil { + if err := release.UpdateReleaseVersion(); err != nil { return fmt.Errorf("updateReleaseVersion: %w", err) } @@ -125,7 +118,7 @@ func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { return nil } -func runBackgroundJobs(ctx context.Context) { +func runBackgroundJobs(ctx context.Context, srv *server.Server) { time.Sleep(5 * time.Second) for { err := cron(ctx, GLOBAL_DB, GLOBAL_STATSD) @@ -135,80 +128,12 @@ func runBackgroundJobs(ctx context.Context) { // cron no longer panics, panicking here. panic(err) } + srv.UpdateReleaseVersion(release.Version, release.BuildUpdateInfo(release.Version)) time.Sleep(10 * time.Minute) } } -type releaseInfo struct { - Name string `json:"name"` -} - -func updateReleaseVersion() error { - resp, err := http.Get("https://api.github.com/repos/ddworken/hishtory/releases/latest") - if err != nil { - return fmt.Errorf("failed to get latest release version: %w", err) - } - defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read github API response body: %w", err) - } - if resp.StatusCode == 403 && strings.Contains(string(respBody), "API rate limit exceeded for ") { - return nil - } - if resp.StatusCode != 200 { - return fmt.Errorf("failed to call github API, status_code=%d, body=%#v", resp.StatusCode, string(respBody)) - } - var info releaseInfo - err = json.Unmarshal(respBody, &info) - if err != nil { - return fmt.Errorf("failed to parse github API response: %w", err) - } - latestVersionTag := info.Name - ReleaseVersion = decrementVersionIfInvalid(latestVersionTag) - return nil -} - -func decrementVersionIfInvalid(initialVersion string) string { - // Decrements the version up to 5 times if the version doesn't have valid binaries yet. - version := initialVersion - for i := 0; i < 5; i++ { - updateInfo := buildUpdateInfo(version) - err := assertValidUpdate(updateInfo) - if err == nil { - fmt.Printf("Found a valid version: %v\n", version) - return version - } - fmt.Printf("Found %s to be an invalid version: %v\n", version, err) - version, err = decrementVersion(version) - if err != nil { - fmt.Printf("Failed to decrement version after finding the latest version was invalid: %v\n", err) - return initialVersion - } - } - fmt.Printf("Decremented the version 5 times and failed to find a valid version version number, initial version number: %v, last checked version number: %v\n", initialVersion, version) - return initialVersion -} - -func assertValidUpdate(updateInfo shared.UpdateInfo) error { - urls := []string{updateInfo.LinuxAmd64Url, updateInfo.LinuxAmd64AttestationUrl, updateInfo.LinuxArm64Url, updateInfo.LinuxArm64AttestationUrl, - updateInfo.LinuxArm7Url, updateInfo.LinuxArm7AttestationUrl, - updateInfo.DarwinAmd64Url, updateInfo.DarwinAmd64UnsignedUrl, updateInfo.DarwinAmd64AttestationUrl, - updateInfo.DarwinArm64Url, updateInfo.DarwinArm64UnsignedUrl, updateInfo.DarwinArm64AttestationUrl} - for _, url := range urls { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("failed to retrieve URL %#v: %w", url, err) - } - defer resp.Body.Close() - if resp.StatusCode == 404 { - return fmt.Errorf("URL %#v returned 404", url) - } - } - return nil -} - -func InitDB() { +func InitDB() *database.DB { var err error GLOBAL_DB, err = OpenDB() if err != nil { @@ -228,39 +153,8 @@ func InitDB() { panic(fmt.Errorf("failed to set max idle conns: %w", err)) } } -} -func decrementVersion(version string) (string, error) { - if version == "UNKNOWN" { - return "", fmt.Errorf("cannot decrement UNKNOWN") - } - parts := strings.Split(version, ".") - if len(parts) != 2 { - return "", fmt.Errorf("invalid version: %s", version) - } - versionNumber, err := strconv.Atoi(parts[1]) - if err != nil { - return "", fmt.Errorf("invalid version: %s", version) - } - return parts[0] + "." + strconv.Itoa(versionNumber-1), nil -} - -func buildUpdateInfo(version string) shared.UpdateInfo { - return shared.UpdateInfo{ - LinuxAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64", version), - LinuxAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64.intoto.jsonl", version), - LinuxArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64", version), - LinuxArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64.intoto.jsonl", version), - LinuxArm7Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm", version), - LinuxArm7AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm.intoto.jsonl", version), - DarwinAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64", version), - DarwinAmd64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64-unsigned", version), - DarwinAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64.intoto.jsonl", version), - DarwinArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64", version), - DarwinArm64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64-unsigned", version), - DarwinArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64.intoto.jsonl", version), - Version: version, - } + return GLOBAL_DB } func main() { @@ -275,13 +169,15 @@ func main() { srv := server.NewServer( GLOBAL_DB, server.WithStatsd(s), - server.WithReleaseVersion(ReleaseVersion), + server.WithReleaseVersion(release.Version), server.IsTestEnvironment(isTestEnvironment()), server.IsProductionEnvironment(isProductionEnvironment()), server.WithCron(cron), - server.WithUpdateInfo(buildUpdateInfo(ReleaseVersion)), + server.WithUpdateInfo(release.BuildUpdateInfo(release.Version)), ) + go runBackgroundJobs(context.Background(), srv) + if err := srv.Run(context.Background(), ":8080"); err != nil { panic(err) } diff --git a/internal/release/release.go b/internal/release/release.go new file mode 100644 index 0000000..36b03c3 --- /dev/null +++ b/internal/release/release.go @@ -0,0 +1,127 @@ +package release + +import ( + "encoding/json" + "fmt" + "github.com/ddworken/hishtory/shared" + "io" + "net/http" + "strconv" + "strings" +) + +var Version = "UNKNOWN" + +type releaseInfo struct { + Name string `json:"name"` +} + +const releaseURL = "https://api.github.com/repos/ddworken/hishtory/releases/latest" + +func UpdateReleaseVersion() error { + resp, err := http.Get(releaseURL) + if err != nil { + return fmt.Errorf("failed to get latest release version: %w", err) + } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read github API response body: %w", err) + } + if resp.StatusCode == 403 && strings.Contains(string(respBody), "API rate limit exceeded for ") { + return nil + } + if resp.StatusCode != 200 { + return fmt.Errorf("failed to call github API, status_code=%d, body=%#v", resp.StatusCode, string(respBody)) + } + var info releaseInfo + err = json.Unmarshal(respBody, &info) + if err != nil { + return fmt.Errorf("failed to parse github API response: %w", err) + } + latestVersionTag := info.Name + Version = decrementVersionIfInvalid(latestVersionTag) + return nil +} + +func BuildUpdateInfo(version string) shared.UpdateInfo { + return shared.UpdateInfo{ + LinuxAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64", version), + LinuxAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64.intoto.jsonl", version), + LinuxArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64", version), + LinuxArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64.intoto.jsonl", version), + LinuxArm7Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm", version), + LinuxArm7AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm.intoto.jsonl", version), + DarwinAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64", version), + DarwinAmd64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64-unsigned", version), + DarwinAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64.intoto.jsonl", version), + DarwinArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64", version), + DarwinArm64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64-unsigned", version), + DarwinArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64.intoto.jsonl", version), + Version: version, + } +} + +func decrementVersionIfInvalid(initialVersion string) string { + // Decrements the version up to 5 times if the version doesn't have valid binaries yet. + version := initialVersion + for i := 0; i < 5; i++ { + updateInfo := BuildUpdateInfo(version) + err := assertValidUpdate(updateInfo) + if err == nil { + fmt.Printf("Found a valid version: %v\n", version) + return version + } + fmt.Printf("Found %s to be an invalid version: %v\n", version, err) + version, err = decrementVersion(version) + if err != nil { + fmt.Printf("Failed to decrement version after finding the latest version was invalid: %v\n", err) + return initialVersion + } + } + fmt.Printf("Decremented the version 5 times and failed to find a valid version version number, initial version number: %v, last checked version number: %v\n", initialVersion, version) + return initialVersion +} + +func assertValidUpdate(updateInfo shared.UpdateInfo) error { + urls := []string{ + updateInfo.LinuxAmd64Url, + updateInfo.LinuxAmd64AttestationUrl, + updateInfo.LinuxArm64Url, + updateInfo.LinuxArm64AttestationUrl, + updateInfo.LinuxArm7Url, + updateInfo.LinuxArm7AttestationUrl, + updateInfo.DarwinAmd64Url, + updateInfo.DarwinAmd64UnsignedUrl, + updateInfo.DarwinAmd64AttestationUrl, + updateInfo.DarwinArm64Url, + updateInfo.DarwinArm64UnsignedUrl, + updateInfo.DarwinArm64AttestationUrl, + } + for _, url := range urls { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to retrieve URL %#v: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return fmt.Errorf("URL %#v returned 404", url) + } + } + return nil +} + +func decrementVersion(version string) (string, error) { + if version == "UNKNOWN" { + return "", fmt.Errorf("cannot decrement UNKNOWN") + } + parts := strings.Split(version, ".") + if len(parts) != 2 { + return "", fmt.Errorf("invalid version: %s", version) + } + versionNumber, err := strconv.Atoi(parts[1]) + if err != nil { + return "", fmt.Errorf("invalid version: %s", version) + } + return parts[0] + "." + strconv.Itoa(versionNumber-1), nil +} diff --git a/internal/release/release_test.go b/internal/release/release_test.go new file mode 100644 index 0000000..5cba118 --- /dev/null +++ b/internal/release/release_test.go @@ -0,0 +1,33 @@ +package release + +import ( + "github.com/ddworken/hishtory/shared/testutils" + "strings" + "testing" +) + +func TestUpdateReleaseVersion(t *testing.T) { + if !testutils.IsOnline() { + t.Skip("skipping because we're currently offline") + } + + // Check that ReleaseVersion hasn't been set yet + if Version != "UNKNOWN" { + t.Fatalf("initial ReleaseVersion isn't as expected: %#v", Version) + } + + // Update it + err := UpdateReleaseVersion() + if err != nil { + t.Fatalf("updateReleaseVersion failed: %v", err) + } + + // If ReleaseVersion is still unknown, skip because we're getting rate limited + if Version == "UNKNOWN" { + t.Skip() + } + // Otherwise, check that the new value looks reasonable + if !strings.HasPrefix(Version, "v0.") { + t.Fatalf("ReleaseVersion wasn't updated to contain a version: %#v", Version) + } +} diff --git a/backend/server/server_test.go b/internal/server/server_test.go similarity index 93% rename from backend/server/server_test.go rename to internal/server/server_test.go index 3a8b26b..19dfbf8 100644 --- a/backend/server/server_test.go +++ b/internal/server/server_test.go @@ -1,12 +1,13 @@ -package main +package server import ( "bytes" "context" "encoding/json" + "fmt" "github.com/ddworken/hishtory/internal/database" - "github.com/ddworken/hishtory/internal/server" "github.com/stretchr/testify/require" + "gorm.io/gorm" "io" "net/http" "net/http/httptest" @@ -22,10 +23,32 @@ import ( "github.com/google/uuid" ) +var DB *database.DB + +const testDBDSN = "file::memory:?_journal_mode=WAL&cache=shared" + +func TestMain(m *testing.M) { + // setup test database + db, err := database.OpenSQLite(testDBDSN, &gorm.Config{}) + if err != nil { + panic(fmt.Errorf("failed to connect to the DB: %w", err)) + } + underlyingDb, err := db.DB.DB() + if err != nil { + panic(fmt.Errorf("failed to access underlying DB: %w", err)) + } + underlyingDb.SetMaxOpenConns(1) + db.Exec("PRAGMA journal_mode = WAL") + db.AddDatabaseTables() + + DB = db + + os.Exit(m.Run()) +} + func TestESubmitThenQuery(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Register a few devices userId := data.UserId("key") @@ -120,13 +143,12 @@ func TestESubmitThenQuery(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestDumpRequestAndResponse(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Register a first device for two different users userId := data.UserId("dkey") @@ -288,45 +310,12 @@ func TestDumpRequestAndResponse(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) -} - -func TestUpdateReleaseVersion(t *testing.T) { - if !testutils.IsOnline() { - t.Skip("skipping because we're currently offline") - } - - // Set up - InitDB() - - // Check that ReleaseVersion hasn't been set yet - if ReleaseVersion != "UNKNOWN" { - t.Fatalf("initial ReleaseVersion isn't as expected: %#v", ReleaseVersion) - } - - // Update it - err := updateReleaseVersion() - if err != nil { - t.Fatalf("updateReleaseVersion failed: %v", err) - } - - // If ReleaseVersion is still unknown, skip because we're getting rate limited - if ReleaseVersion == "UNKNOWN" { - t.Skip() - } - // Otherwise, check that the new value looks reasonable - if !strings.HasPrefix(ReleaseVersion, "v0.") { - t.Fatalf("ReleaseVersion wasn't updated to contain a version: %#v", ReleaseVersion) - } - - // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestDeletionRequests(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Register two devices for two different users userId := data.UserId("dkey") @@ -504,11 +493,11 @@ func TestDeletionRequests(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestHealthcheck(t *testing.T) { - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) w := httptest.NewRecorder() s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { @@ -523,15 +512,20 @@ func TestHealthcheck(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestLimitRegistrations(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) - checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) - checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices")) + s := NewServer(DB) + + if resp := DB.Exec("DELETE FROM enc_history_entries"); resp.Error != nil { + t.Fatalf("failed to delete enc_history_entries: %v", resp.Error) + } + + if resp := DB.Exec("DELETE FROM devices"); resp.Error != nil { + t.Fatalf("failed to delete devices: %v", resp.Error) + } defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() os.Setenv("HISHTORY_MAX_NUM_USERS", "2") @@ -552,8 +546,7 @@ func TestLimitRegistrations(t *testing.T) { func TestCleanDatabaseNoErrors(t *testing.T) { // Init - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Create a user and an entry userId := data.UserId("dkey") @@ -570,7 +563,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics - testutils.Check(t, GLOBAL_DB.Clean(context.TODO())) + testutils.Check(t, DB.Clean(context.TODO())) } func assertNoLeakedConnections(t *testing.T, db *database.DB) { diff --git a/internal/server/srv.go b/internal/server/srv.go index b30ed65..31b9d9b 100644 --- a/internal/server/srv.go +++ b/internal/server/srv.go @@ -125,6 +125,11 @@ func (s *Server) Run(ctx context.Context, addr string) error { return nil } +func (s *Server) UpdateReleaseVersion(v string, updateInfo shared.UpdateInfo) { + s.releaseVersion = v + s.updateInfo = updateInfo +} + func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") diff --git a/shared/testutils/testutils.go b/shared/testutils/testutils.go index 77ec18a..38e5a7f 100644 --- a/shared/testutils/testutils.go +++ b/shared/testutils/testutils.go @@ -228,7 +228,7 @@ func buildServer() string { f, err := os.CreateTemp("", "server") checkError(err) fn := f.Name() - cmd := exec.Command("go", "build", "-o", fn, "-ldflags", fmt.Sprintf("-X main.ReleaseVersion=v0.%s", version), "backend/server/server.go") + cmd := exec.Command("go", "build", "-o", fn, "-ldflags", fmt.Sprintf("-X release.Version=v0.%s", version), "backend/server/server.go") var stdout bytes.Buffer cmd.Stdout = &stdout var stderr bytes.Buffer From a8360efa67ba3c7e6970954e9ec045cc26050321 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 15:55:41 -0400 Subject: [PATCH 4/6] revert main.ReleaseVersion changes --- backend/server/Dockerfile | 2 +- backend/server/server.go | 10 ++++++---- shared/testutils/testutils.go | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/backend/server/Dockerfile b/backend/server/Dockerfile index 909cc61..55496c4 100644 --- a/backend/server/Dockerfile +++ b/backend/server/Dockerfile @@ -6,7 +6,7 @@ RUN go mod download COPY . ./ ARG GOARCH RUN apk add --update --no-cache --virtual .build-deps build-base && \ - GOARCH=${GOARCH} go build -o /server -ldflags "-X release.Version=v0.`cat VERSION`" backend/server/server.go && \ + GOARCH=${GOARCH} go build -o /server -ldflags "-X main.ReleaseVersion=v0.`cat VERSION`" backend/server/server.go && \ apk del .build-deps FROM alpine:3.17 diff --git a/backend/server/server.go b/backend/server/server.go index 9a4e847..3a1498b 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -3,8 +3,6 @@ package main import ( "context" "fmt" - "github.com/ddworken/hishtory/internal/release" - "github.com/ddworken/hishtory/internal/server" "log" "os" "runtime" @@ -12,6 +10,8 @@ import ( "github.com/DataDog/datadog-go/statsd" "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/internal/release" + "github.com/ddworken/hishtory/internal/server" _ "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -23,8 +23,9 @@ const ( ) var ( - GLOBAL_DB *database.DB - GLOBAL_STATSD *statsd.Client + GLOBAL_DB *database.DB + GLOBAL_STATSD *statsd.Client + ReleaseVersion string ) func isTestEnvironment() bool { @@ -96,6 +97,7 @@ func OpenDB() (*database.DB, error) { } func init() { + release.Version = ReleaseVersion if release.Version == "UNKNOWN" && !isTestEnvironment() { panic("server.go was built without a ReleaseVersion!") } diff --git a/shared/testutils/testutils.go b/shared/testutils/testutils.go index 38e5a7f..77ec18a 100644 --- a/shared/testutils/testutils.go +++ b/shared/testutils/testutils.go @@ -228,7 +228,7 @@ func buildServer() string { f, err := os.CreateTemp("", "server") checkError(err) fn := f.Name() - cmd := exec.Command("go", "build", "-o", fn, "-ldflags", fmt.Sprintf("-X release.Version=v0.%s", version), "backend/server/server.go") + cmd := exec.Command("go", "build", "-o", fn, "-ldflags", fmt.Sprintf("-X main.ReleaseVersion=v0.%s", version), "backend/server/server.go") var stdout bytes.Buffer cmd.Stdout = &stdout var stderr bytes.Buffer From 7638751bd62495541ea7dea965a20ead19906bc4 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 15:56:05 -0400 Subject: [PATCH 5/6] fix functions with changed names --- internal/database/db.go | 1 + internal/server/api.go | 21 +++++++++++---------- internal/server/srv.go | 37 ++++++++++++++----------------------- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/internal/database/db.go b/internal/database/db.go index 93c03b3..5218c8c 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/ddworken/hishtory/shared" "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" diff --git a/internal/server/api.go b/internal/server/api.go index bf27b0a..fdda038 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -3,13 +3,14 @@ package server import ( "encoding/json" "fmt" - "github.com/ddworken/hishtory/shared" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "html" "io" "math" "net/http" "time" + + "github.com/ddworken/hishtory/shared" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { @@ -43,7 +44,7 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { } fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) - err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) + err = s.db.AddHistoryEntriesForAllDevices(r.Context(), devices, entries) if err != nil { panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) } @@ -66,7 +67,7 @@ func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { fmt.Printf("updateUsageData: %v\n", err) } - historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) + historyEntries, err := s.db.AllHistoryEntriesForUser(r.Context(), userId) checkGormError(err) fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) if err := json.NewEncoder(w).Encode(historyEntries); err != nil { @@ -96,7 +97,7 @@ func (s *Server) apiQueryHandler(w http.ResponseWriter, r *http.Request) { } // Then retrieve - historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) + historyEntries, err := s.db.HistoryEntriesForDevice(r.Context(), deviceId, 5) checkGormError(err) fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) if err := json.NewEncoder(w).Encode(historyEntries); err != nil { @@ -108,11 +109,11 @@ func (s *Server) apiQueryHandler(w http.ResponseWriter, r *http.Request) { if s.isProductionEnvironment { go func() { span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") - err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + err := s.db.IncrementEntryReadCountsForDevice(ctx, deviceId) span.Finish(tracer.WithError(err)) }() } else { - err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + err := s.db.IncrementEntryReadCountsForDevice(ctx, deviceId) if err != nil { panic("failed to increment read counts") } @@ -146,7 +147,7 @@ func (s *Server) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { } } - err = s.db.EncHistoryCreateMulti(r.Context(), entries...) + err = s.db.AddHistoryEntries(r.Context(), entries...) checkGormError(err) err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) checkGormError(err) @@ -209,10 +210,10 @@ func (s *Server) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") - existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) + existingDevicesCount, err := s.db.CountDevicesForUser(r.Context(), userId) checkGormError(err) fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) - if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { + if err := s.db.CreateDevice(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { checkGormError(err) } diff --git a/internal/server/srv.go b/internal/server/srv.go index 31b9d9b..a12032b 100644 --- a/internal/server/srv.go +++ b/internal/server/srv.go @@ -219,28 +219,19 @@ func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) healthCheckHandler(w http.ResponseWriter, r *http.Request) { if s.isProductionEnvironment { - // Check that we have a reasonable looking set of devices/entries in the DB - //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() - //if err != nil { - // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) - //} - //defer rows.Close() - //if !rows.Next() { - // panic("Suspiciously few enc history entries!") - //} - encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) + encHistoryEntryCount, err := s.db.CountHistoryEntries(r.Context()) checkGormError(err) if encHistoryEntryCount < 1000 { panic("Suspiciously few enc history entries!") } - deviceCount, err := s.db.DevicesCount(r.Context()) + deviceCount, err := s.db.CountAllDevices(r.Context()) checkGormError(err) if deviceCount < 100 { panic("Suspiciously few devices!") } // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. - err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ + err = s.db.AddHistoryEntries(r.Context(), &shared.EncHistoryEntry{ EncryptedData: []byte("data"), Nonce: []byte("nonce"), DeviceId: "healthcheck_device_id", @@ -285,23 +276,23 @@ func (s *Server) usageStatsHandler(w http.ResponseWriter, r *http.Request) { } func (s *Server) statsHandler(w http.ResponseWriter, r *http.Request) { - numDevices, err := s.db.DevicesCount(r.Context()) + numDevices, err := s.db.CountAllDevices(r.Context()) checkGormError(err) numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) checkGormError(err) - numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) + numDbEntries, err := s.db.CountHistoryEntries(r.Context()) checkGormError(err) oneWeek := time.Hour * 24 * 7 - weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) + weeklyActiveInstalls, err := s.db.CountActiveInstalls(r.Context(), oneWeek) checkGormError(err) - weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) + weeklyQueryUsers, err := s.db.CountQueryUsers(r.Context(), oneWeek) checkGormError(err) - lastRegistration, err := s.db.LastRegistration(r.Context()) + lastRegistration, err := s.db.DateOfLastRegistration(r.Context()) checkGormError(err) _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) @@ -320,7 +311,7 @@ func (s *Server) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { panic("refusing to wipe the DB non-test environment") } - err := s.db.EncHistoryClear(r.Context()) + err := s.db.Unsafe_DeleteAllHistoryEntries(r.Context()) checkGormError(err) w.Header().Set("Content-Length", "0") @@ -343,7 +334,7 @@ func (s *Server) updateUsageData(ctx context.Context, version string, remoteAddr return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) } if len(usageData) == 0 { - err := s.db.UsageDataCreate( + err := s.db.CreateUsageData( ctx, &shared.UsageData{ UserId: userId, @@ -359,22 +350,22 @@ func (s *Server) updateUsageData(ctx context.Context, version string, remoteAddr } else { usage := usageData[0] - if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { + if err := s.db.UpdateUsageData(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { return fmt.Errorf("db.UsageDataUpdate: %w", err) } if numEntriesHandled > 0 { - if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { + if err := s.db.UpdateUsageDataForNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) } } if usage.Version != version { - if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { + if err := s.db.UpdateUsageDataClientVersion(ctx, userId, deviceId, version); err != nil { return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) } } } if isQuery { - if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { + if err := s.db.UpdateUsageDataNumberQueries(ctx, userId, deviceId); err != nil { return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) } } From b93a365055d35d2d6fb87fe1dba26231b1a51710 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Wed, 13 Sep 2023 10:35:18 -0400 Subject: [PATCH 6/6] use actions/checkout@v4 for go-test.yml --- .github/workflows/go-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index eceab0d..dfb1e41 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -16,7 +16,7 @@ jobs: os: [ubuntu-latest, macos-latest] fail-fast: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v3 with: