From 60a0e20dd9b026c3f12b2021d6335acb4862bcce Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 09:48:48 -0400 Subject: [PATCH] extract server object to its own package --- backend/server/server.go | 69 +-- backend/server/server_test.go | 13 +- backend/server/srv.go | 611 --------------------- internal/server/api.go | 238 ++++++++ {backend => internal}/server/middleware.go | 9 +- internal/server/srv.go | 378 +++++++++++++ internal/server/util.go | 93 ++++ 7 files changed, 744 insertions(+), 667 deletions(-) delete mode 100644 backend/server/srv.go create mode 100644 internal/server/api.go rename {backend => internal}/server/middleware.go (99%) create mode 100644 internal/server/srv.go create mode 100644 internal/server/util.go diff --git a/backend/server/server.go b/backend/server/server.go index 4cc8377..bc65d50 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -4,9 +4,9 @@ import ( "context" "encoding/json" "fmt" + "github.com/ddworken/hishtory/internal/server" "io" "log" - "math" "net/http" "os" "runtime" @@ -33,26 +33,6 @@ var ( ReleaseVersion string = "UNKNOWN" ) -func getRequiredQueryParam(r *http.Request, queryParam string) string { - val := r.URL.Query().Get(queryParam) - if val == "" { - panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam)) - } - return val -} - -func getHishtoryVersion(r *http.Request) string { - return r.Header.Get("X-Hishtory-Version") -} - -func getRemoteAddr(r *http.Request) string { - addr, ok := r.Header["X-Real-Ip"] - if !ok || len(addr) == 0 { - return "UnknownIp" - } - return addr[0] -} - func isTestEnvironment() bool { return os.Getenv("HISHTORY_TEST") != "" } @@ -129,19 +109,17 @@ func init() { go runBackgroundJobs(context.Background()) } -func cron(ctx context.Context) error { - err := updateReleaseVersion() - if err != nil { - panic(err) +func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { + if err := updateReleaseVersion(); err != nil { + return fmt.Errorf("updateReleaseVersion: %w", err) } - err = GLOBAL_DB.Clean(ctx) - if err != nil { - panic(err) + + if err := db.Clean(ctx); err != nil { + return fmt.Errorf("db.Clean: %w", err) } - if GLOBAL_STATSD != nil { - err = GLOBAL_STATSD.Flush() - if err != nil { - panic(err) + if stats != nil { + if err := stats.Flush(); err != nil { + return fmt.Errorf("stats.Flush: %w", err) } } return nil @@ -150,9 +128,12 @@ func cron(ctx context.Context) error { func runBackgroundJobs(ctx context.Context) { time.Sleep(5 * time.Second) for { - err := cron(ctx) + err := cron(ctx, GLOBAL_DB, GLOBAL_STATSD) if err != nil { fmt.Printf("Cron failure: %v", err) + + // cron no longer panics, panicking here. + panic(err) } time.Sleep(10 * time.Minute) } @@ -291,7 +272,15 @@ func main() { // TODO: remove this global once we have a better way to pass it around GLOBAL_STATSD = s - srv := NewServer(GLOBAL_DB, WithStatsd(s)) + srv := server.NewServer( + GLOBAL_DB, + server.WithStatsd(s), + server.WithReleaseVersion(ReleaseVersion), + server.IsTestEnvironment(isTestEnvironment()), + server.IsProductionEnvironment(isProductionEnvironment()), + server.WithCron(cron), + server.WithUpdateInfo(buildUpdateInfo(ReleaseVersion)), + ) if err := srv.Run(context.Background(), ":8080"); err != nil { panic(err) @@ -311,17 +300,5 @@ func checkGormError(err error, skip int) { panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err)) } -func getMaximumNumberOfAllowedUsers() int { - maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS") - if maxNumUsersStr == "" { - return math.MaxInt - } - maxNumUsers, err := strconv.Atoi(maxNumUsersStr) - if err != nil { - return math.MaxInt - } - return maxNumUsers -} - // TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required? // TODO: Add error checking for the calls to updateUsageData(...) that logs it/triggers an alert in prod, but is an error in test diff --git a/backend/server/server_test.go b/backend/server/server_test.go index 9187409..3a8b26b 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/internal/server" "github.com/stretchr/testify/require" "io" "net/http" @@ -24,7 +25,7 @@ import ( func TestESubmitThenQuery(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Register a few devices userId := data.UserId("key") @@ -125,7 +126,7 @@ func TestESubmitThenQuery(t *testing.T) { func TestDumpRequestAndResponse(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Register a first device for two different users userId := data.UserId("dkey") @@ -325,7 +326,7 @@ func TestUpdateReleaseVersion(t *testing.T) { func TestDeletionRequests(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Register two devices for two different users userId := data.UserId("dkey") @@ -507,7 +508,7 @@ func TestDeletionRequests(t *testing.T) { } func TestHealthcheck(t *testing.T) { - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) w := httptest.NewRecorder() s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { @@ -528,7 +529,7 @@ func TestHealthcheck(t *testing.T) { func TestLimitRegistrations(t *testing.T) { // Set up InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices")) defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() @@ -552,7 +553,7 @@ func TestLimitRegistrations(t *testing.T) { func TestCleanDatabaseNoErrors(t *testing.T) { // Init InitDB() - s := NewServer(GLOBAL_DB) + s := server.NewServer(GLOBAL_DB) // Create a user and an entry userId := data.UserId("dkey") diff --git a/backend/server/srv.go b/backend/server/srv.go deleted file mode 100644 index c373167..0000000 --- a/backend/server/srv.go +++ /dev/null @@ -1,611 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "html" - "io" - "math" - "net/http" - pprofhttp "net/http/pprof" - "os" - "strconv" - "strings" - "time" - - "github.com/DataDog/datadog-go/statsd" - "github.com/ddworken/hishtory/internal/database" - "github.com/ddworken/hishtory/shared" - "github.com/rodaine/table" - httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/DataDog/dd-trace-go.v1/profiler" -) - -type Srv struct { - db *database.DB - statsd *statsd.Client -} - -type ServerOption func(*Srv) - -func WithStatsd(statsd *statsd.Client) ServerOption { - return func(s *Srv) { - s.statsd = statsd - } -} - -func NewServer(db *database.DB, options ...ServerOption) *Srv { - srv := Srv{db: db} - for _, option := range options { - option(&srv) - } - return &srv -} - -func (s *Srv) Run(ctx context.Context, addr string) error { - mux := httptrace.NewServeMux() - - if isProductionEnvironment() { - defer configureObservability(mux)() - go func() { - if err := s.db.DeepClean(ctx); err != nil { - panic(err) - } - }() - } - loggerMiddleware := withLogging(s.statsd) - - mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) - mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) - mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) - mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) - mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) - mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) - mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) - mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) - mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) - mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) - mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) - mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) - mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) - mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) - mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) - mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) - if isTestEnvironment() { - mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) - mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) - } - - httpServer := &http.Server{ - Addr: addr, - Handler: mux, - } - - fmt.Printf("Listening on %s\n", addr) - if err := httpServer.ListenAndServe(); err != nil { - if !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("http.ListenAndServe: %w", err) - } - } - - return nil -} - -func (s *Srv) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) - if len(entries) == 0 { - return - } - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId) - checkGormError(err, 0) - - if len(devices) == 0 { - panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) - } - fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) - - err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) - if err != nil { - panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) - } - if s.statsd != nil { - s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) - checkGormError(err, 1) - fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } -} - -func (s *Srv) apiQueryHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - // Delete any entries that match a pending deletion request - deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - for _, request := range deletionRequests { - _, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request) - checkGormError(err, 0) - } - - // Then retrieve - historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) - checkGormError(err, 0) - fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) - if err := json.NewEncoder(w).Encode(historyEntries); err != nil { - panic(err) - } - - // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids - // blocking the entire response. This does have a potential race condition, but that is fine. - if isProductionEnvironment() { - go func() { - span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") - err := s.db.DeviceIncrementReadCounts(ctx, deviceId) - span.Finish(tracer.WithError(err)) - }() - } else { - err := s.db.DeviceIncrementReadCounts(ctx, deviceId) - if err != nil { - panic("failed to increment read counts") - } - } - - if s.statsd != nil { - s.statsd.Incr("hishtory.query", []string{}, 1.0) - } -} - -func (s *Srv) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - srcDeviceId := getRequiredQueryParam(r, "source_device_id") - requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var entries []*shared.EncHistoryEntry - err = json.Unmarshal(data, &entries) - if err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) - - // sanity check - for _, entry := range entries { - entry.DeviceId = requestingDeviceId - if entry.UserId != userId { - panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) - } - } - - err = s.db.EncHistoryCreateMulti(r.Context(), entries...) - checkGormError(err, 0) - err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) - checkGormError(err, 0) - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) apiBannerHandler(w http.ResponseWriter, r *http.Request) { - commitHash := getRequiredQueryParam(r, "commit_hash") - deviceId := getRequiredQueryParam(r, "device_id") - forcedBanner := r.URL.Query().Get("forced_banner") - fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) - if getHishtoryVersion(r) == "v0.160" { - w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) - return - } - w.Write([]byte(html.EscapeString(forcedBanner))) -} - -func (s *Srv) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - var dumpRequests []*shared.DumpRequest - // Filter out ones requested by the hishtory instance that sent this request - dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - - if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func (s *Srv) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - // Increment the ReadCount - err := s.db.DeletionRequestInc(r.Context(), userId, deviceId) - checkGormError(err, 0) - - // Return all the deletion requests - deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) - checkGormError(err, 0) - if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { - panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) - } -} - -func (s *Srv) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var request shared.DeletionRequest - - if err := json.Unmarshal(data, &request); err != nil { - panic(fmt.Sprintf("body=%#v, err=%v", data, err)) - } - request.ReadCount = 0 - fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) - - err = s.db.DeletionRequestCreate(r.Context(), &request) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (_ *Srv) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { - updateInfo := buildUpdateInfo(ReleaseVersion) - resp, err := json.Marshal(updateInfo) - if err != nil { - panic(err) - } - w.Write(resp) -} - -func (s *Srv) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { - if getMaximumNumberOfAllowedUsers() < math.MaxInt { - numDistinctUsers, err := s.db.DistinctUsers(r.Context()) - if err != nil { - panic(fmt.Errorf("db.DistinctUsers: %w", err)) - } - if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { - panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) - } - } - userId := getRequiredQueryParam(r, "user_id") - deviceId := getRequiredQueryParam(r, "device_id") - - existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) - checkGormError(err, 0) - fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) - if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { - checkGormError(err, 0) - } - - if existingDevicesCount > 0 { - err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) - checkGormError(err, 0) - } - - // TODO: add these to the context in a middleware - version := getHishtoryVersion(r) - remoteIPAddr := getRemoteAddr(r) - - if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { - fmt.Printf("updateUsageData: %v\n", err) - } - - if s.statsd != nil { - s.statsd.Incr("hishtory.register", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) triggerCronHandler(w http.ResponseWriter, r *http.Request) { - err := cron(r.Context()) - if err != nil { - panic(err) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { - // returns "OK" unless there is a current SLSA bug - v := getHishtoryVersion(r) - if !strings.Contains(v, "v0.") { - w.Write([]byte("OK")) - return - } - vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) - if err != nil { - w.Write([]byte("OK")) - return - } - if vNum < 159 { - w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) - return - } - w.Write([]byte("OK")) -} - -func (s *Srv) feedbackHandler(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - var feedback shared.Feedback - err = json.Unmarshal(data, &feedback) - if err != nil { - panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) - } - fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) - err = s.db.FeedbackCreate(r.Context(), &feedback) - checkGormError(err, 0) - - if s.statsd != nil { - s.statsd.Incr("hishtory.uninstall", []string{}, 1.0) - } - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) healthCheckHandler(w http.ResponseWriter, r *http.Request) { - if isProductionEnvironment() { - // Check that we have a reasonable looking set of devices/entries in the DB - //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() - //if err != nil { - // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) - //} - //defer rows.Close() - //if !rows.Next() { - // panic("Suspiciously few enc history entries!") - //} - encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) - checkGormError(err, 0) - if encHistoryEntryCount < 1000 { - panic("Suspiciously few enc history entries!") - } - - deviceCount, err := s.db.DevicesCount(r.Context()) - checkGormError(err, 0) - if deviceCount < 100 { - panic("Suspiciously few devices!") - } - // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. - err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ - EncryptedData: []byte("data"), - Nonce: []byte("nonce"), - DeviceId: "healthcheck_device_id", - UserId: "healthcheck_user_id", - Date: time.Now(), - EncryptedId: "healthcheck_enc_id", - ReadCount: 10000, - }) - checkGormError(err, 0) - } else { - err := s.db.Ping() - if err != nil { - panic(fmt.Errorf("failed to ping DB: %w", err)) - } - } - w.Write([]byte("OK")) -} - -func (s *Srv) usageStatsHandler(w http.ResponseWriter, r *http.Request) { - usageData, err := s.db.UsageDataStats(r.Context()) - if err != nil { - panic(fmt.Errorf("db.UsageDataStats: %w", err)) - } - - tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") - tbl.WithWriter(w) - for _, data := range usageData { - versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") - lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") - tbl.AddRow( - data.RegistrationDate.Format(shared.DateOnly), - data.NumDevices, - data.NumEntries, - data.NumQueries, - data.LastUsedDate.Format(shared.DateOnly), - lastQueryStr, - versions, - data.IpAddresses, - ) - } - tbl.Print() -} - -func (s *Srv) statsHandler(w http.ResponseWriter, r *http.Request) { - numDevices, err := s.db.DevicesCount(r.Context()) - checkGormError(err, 0) - - numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) - checkGormError(err, 0) - - numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) - checkGormError(err, 0) - - oneWeek := time.Hour * 24 * 7 - weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) - checkGormError(err, 0) - - weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) - checkGormError(err, 0) - - lastRegistration, err := s.db.LastRegistration(r.Context()) - checkGormError(err, 0) - - _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) - _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) - _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) - _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) - _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) - _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) -} - -func (s *Srv) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { - if r.Host == "api.hishtory.dev" || isProductionEnvironment() { - panic("refusing to wipe the DB for prod") - } - if !isTestEnvironment() { - panic("refusing to wipe the DB non-test environment") - } - - err := s.db.EncHistoryClear(r.Context()) - checkGormError(err, 0) - - w.Header().Set("Content-Length", "0") - w.WriteHeader(http.StatusOK) -} - -func (s *Srv) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { - stats, err := s.db.Stats() - if err != nil { - panic(err) - } - - _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) -} - -func (s *Srv) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error { - var usageData []shared.UsageData - usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId) - if err != nil { - return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) - } - if len(usageData) == 0 { - err := s.db.UsageDataCreate( - ctx, - &shared.UsageData{ - UserId: userId, - DeviceId: deviceId, - LastUsed: time.Now(), - NumEntriesHandled: numEntriesHandled, - Version: version, - }, - ) - if err != nil { - return fmt.Errorf("db.UsageDataCreate: %w", err) - } - } else { - usage := usageData[0] - - if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { - return fmt.Errorf("db.UsageDataUpdate: %w", err) - } - if numEntriesHandled > 0 { - if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { - return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) - } - } - if usage.Version != version { - if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { - return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) - } - } - } - if isQuery { - if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { - return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) - } - } - - return nil -} - -func configureObservability(mux *httptrace.ServeMux) func() { - // Profiler - err := profiler.Start( - profiler.WithService("hishtory-api"), - profiler.WithVersion(ReleaseVersion), - profiler.WithAPIKey(os.Getenv("DD_API_KEY")), - profiler.WithUDS("/var/run/datadog/apm.socket"), - profiler.WithProfileTypes( - profiler.CPUProfile, - profiler.HeapProfile, - ), - ) - if err != nil { - fmt.Printf("Failed to start DataDog profiler: %v\n", err) - } - // Tracer - tracer.Start( - tracer.WithRuntimeMetrics(), - tracer.WithService("hishtory-api"), - tracer.WithUDS("/var/run/datadog/apm.socket"), - ) - // TODO: should this be here? - defer tracer.Stop() - - // Pprof - mux.HandleFunc("/debug/pprof/", pprofhttp.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) - - // Func to stop all of the above - return func() { - profiler.Stop() - tracer.Stop() - } -} diff --git a/internal/server/api.go b/internal/server/api.go new file mode 100644 index 0000000..bf27b0a --- /dev/null +++ b/internal/server/api.go @@ -0,0 +1,238 @@ +package server + +import ( + "encoding/json" + "fmt" + "github.com/ddworken/hishtory/shared" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "html" + "io" + "math" + "net/http" + "time" +) + +func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries)) + if len(entries) == 0 { + return + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId) + checkGormError(err) + + if len(devices) == 0 { + panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) + } + fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) + + err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) + if err != nil { + panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) + } + if s.statsd != nil { + s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId) + checkGormError(err) + fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } +} + +func (s *Server) apiQueryHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + // Delete any entries that match a pending deletion request + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err) + for _, request := range deletionRequests { + _, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request) + checkGormError(err) + } + + // Then retrieve + historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) + checkGormError(err) + fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { + panic(err) + } + + // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids + // blocking the entire response. This does have a potential race condition, but that is fine. + if s.isProductionEnvironment { + go func() { + span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + span.Finish(tracer.WithError(err)) + }() + } else { + err := s.db.DeviceIncrementReadCounts(ctx, deviceId) + if err != nil { + panic("failed to increment read counts") + } + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.query", []string{}, 1.0) + } +} + +func (s *Server) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + srcDeviceId := getRequiredQueryParam(r, "source_device_id") + requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id") + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var entries []*shared.EncHistoryEntry + err = json.Unmarshal(data, &entries) + if err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) + + // sanity check + for _, entry := range entries { + entry.DeviceId = requestingDeviceId + if entry.UserId != userId { + panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) + } + } + + err = s.db.EncHistoryCreateMulti(r.Context(), entries...) + checkGormError(err) + err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) + checkGormError(err) + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) apiBannerHandler(w http.ResponseWriter, r *http.Request) { + commitHash := getRequiredQueryParam(r, "commit_hash") + deviceId := getRequiredQueryParam(r, "device_id") + forcedBanner := r.URL.Query().Get("forced_banner") + fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner) + if getHishtoryVersion(r) == "v0.160" { + w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory.")) + return + } + w.Write([]byte(html.EscapeString(forcedBanner))) +} + +func (s *Server) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + var dumpRequests []*shared.DumpRequest + // Filter out ones requested by the hishtory instance that sent this request + dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err) + + if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Server) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { + err := json.NewEncoder(w).Encode(s.updateInfo) + + if err != nil { + panic(fmt.Errorf("failed to JSON marshall the update info: %w", err)) + } +} + +func (s *Server) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { + if getMaximumNumberOfAllowedUsers() < math.MaxInt { + numDistinctUsers, err := s.db.DistinctUsers(r.Context()) + if err != nil { + panic(fmt.Errorf("db.DistinctUsers: %w", err)) + } + if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { + panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) + } + } + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId) + checkGormError(err) + fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount) + if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil { + checkGormError(err) + } + + if existingDevicesCount > 0 { + err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()}) + checkGormError(err) + } + + // TODO: add these to the context in a middleware + version := getHishtoryVersion(r) + remoteIPAddr := getRemoteAddr(r) + + if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + if s.statsd != nil { + s.statsd.Incr("hishtory.register", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} diff --git a/backend/server/middleware.go b/internal/server/middleware.go similarity index 99% rename from backend/server/middleware.go rename to internal/server/middleware.go index 5d8d7e7..986b1dc 100644 --- a/backend/server/middleware.go +++ b/internal/server/middleware.go @@ -1,15 +1,16 @@ -package main +package server import ( "fmt" - "github.com/DataDog/datadog-go/statsd" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" "net/http" "reflect" "runtime" "strings" "time" + + "github.com/DataDog/datadog-go/statsd" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" ) type loggedResponseData struct { diff --git a/internal/server/srv.go b/internal/server/srv.go new file mode 100644 index 0000000..b30ed65 --- /dev/null +++ b/internal/server/srv.go @@ -0,0 +1,378 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/DataDog/datadog-go/statsd" + "github.com/ddworken/hishtory/internal/database" + "github.com/ddworken/hishtory/shared" + "github.com/rodaine/table" + httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" +) + +type Server struct { + db *database.DB + statsd *statsd.Client + + isProductionEnvironment bool + isTestEnvironment bool + releaseVersion string + cronFn CronFn + updateInfo shared.UpdateInfo +} + +type CronFn func(ctx context.Context, db *database.DB, stats *statsd.Client) error +type Option func(*Server) + +func WithStatsd(statsd *statsd.Client) Option { + return func(s *Server) { + s.statsd = statsd + } +} + +func WithReleaseVersion(releaseVersion string) Option { + return func(s *Server) { + s.releaseVersion = releaseVersion + } +} + +func WithCron(cronFn CronFn) Option { + return func(s *Server) { + s.cronFn = cronFn + } +} + +func WithUpdateInfo(updateInfo shared.UpdateInfo) Option { + return func(s *Server) { + s.updateInfo = updateInfo + } +} + +func IsProductionEnvironment(v bool) Option { + return func(s *Server) { + s.isProductionEnvironment = v + } +} + +func IsTestEnvironment(v bool) Option { + return func(s *Server) { + s.isTestEnvironment = v + } +} + +func NewServer(db *database.DB, options ...Option) *Server { + srv := Server{db: db} + for _, option := range options { + option(&srv) + } + return &srv +} + +func (s *Server) Run(ctx context.Context, addr string) error { + mux := httptrace.NewServeMux() + + if s.isProductionEnvironment { + defer configureObservability(mux, s.releaseVersion)() + go func() { + if err := s.db.DeepClean(ctx); err != nil { + panic(err) + } + }() + } + loggerMiddleware := withLogging(s.statsd) + + mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler)) + mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler)) + mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler)) + mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler)) + mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler)) + mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler)) + mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler)) + mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler)) + mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler)) + mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler)) + mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler)) + mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler)) + mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler)) + mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler)) + mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler)) + mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler)) + if s.isTestEnvironment { + mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler)) + mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler)) + } + + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + } + + fmt.Printf("Listening on %s\n", addr) + if err := httpServer.ListenAndServe(); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("http.ListenAndServe: %w", err) + } + } + + return nil +} + +func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { + userId := getRequiredQueryParam(r, "user_id") + deviceId := getRequiredQueryParam(r, "device_id") + + // Increment the ReadCount + err := s.db.DeletionRequestInc(r.Context(), userId, deviceId) + checkGormError(err) + + // Return all the deletion requests + deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err) + if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { + panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) + } +} + +func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var request shared.DeletionRequest + + if err := json.Unmarshal(data, &request); err != nil { + panic(fmt.Sprintf("body=%#v, err=%v", data, err)) + } + request.ReadCount = 0 + fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) + + err = s.db.DeletionRequestCreate(r.Context(), &request) + checkGormError(err) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) triggerCronHandler(w http.ResponseWriter, r *http.Request) { + err := s.cronFn(r.Context(), s.db, s.statsd) + if err != nil { + panic(err) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) slsaStatusHandler(w http.ResponseWriter, r *http.Request) { + // returns "OK" unless there is a current SLSA bug + v := getHishtoryVersion(r) + if !strings.Contains(v, "v0.") { + w.Write([]byte("OK")) + return + } + vNum, err := strconv.Atoi(strings.Split(v, ".")[1]) + if err != nil { + w.Write([]byte("OK")) + return + } + if vNum < 159 { + w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163")) + return + } + w.Write([]byte("OK")) +} + +func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + var feedback shared.Feedback + err = json.Unmarshal(data, &feedback) + if err != nil { + panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) + } + fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) + err = s.db.FeedbackCreate(r.Context(), &feedback) + checkGormError(err) + + if s.statsd != nil { + s.statsd.Incr("hishtory.uninstall", []string{}, 1.0) + } + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) healthCheckHandler(w http.ResponseWriter, r *http.Request) { + if s.isProductionEnvironment { + // Check that we have a reasonable looking set of devices/entries in the DB + //rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() + //if err != nil { + // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) + //} + //defer rows.Close() + //if !rows.Next() { + // panic("Suspiciously few enc history entries!") + //} + encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err) + if encHistoryEntryCount < 1000 { + panic("Suspiciously few enc history entries!") + } + + deviceCount, err := s.db.DevicesCount(r.Context()) + checkGormError(err) + if deviceCount < 100 { + panic("Suspiciously few devices!") + } + // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. + err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ + EncryptedData: []byte("data"), + Nonce: []byte("nonce"), + DeviceId: "healthcheck_device_id", + UserId: "healthcheck_user_id", + Date: time.Now(), + EncryptedId: "healthcheck_enc_id", + ReadCount: 10000, + }) + checkGormError(err) + } else { + err := s.db.Ping() + if err != nil { + panic(fmt.Errorf("failed to ping DB: %w", err)) + } + } + w.Write([]byte("OK")) +} + +func (s *Server) usageStatsHandler(w http.ResponseWriter, r *http.Request) { + usageData, err := s.db.UsageDataStats(r.Context()) + if err != nil { + panic(fmt.Errorf("db.UsageDataStats: %w", err)) + } + + tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs") + tbl.WithWriter(w) + for _, data := range usageData { + versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "") + lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "") + tbl.AddRow( + data.RegistrationDate.Format(shared.DateOnly), + data.NumDevices, + data.NumEntries, + data.NumQueries, + data.LastUsedDate.Format(shared.DateOnly), + lastQueryStr, + versions, + data.IpAddresses, + ) + } + tbl.Print() +} + +func (s *Server) statsHandler(w http.ResponseWriter, r *http.Request) { + numDevices, err := s.db.DevicesCount(r.Context()) + checkGormError(err) + + numEntriesProcessed, err := s.db.UsageDataTotal(r.Context()) + checkGormError(err) + + numDbEntries, err := s.db.EncHistoryEntryCount(r.Context()) + checkGormError(err) + + oneWeek := time.Hour * 24 * 7 + weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek) + checkGormError(err) + + weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek) + checkGormError(err) + + lastRegistration, err := s.db.LastRegistration(r.Context()) + checkGormError(err) + + _, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices) + _, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed) + _, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries) + _, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls) + _, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers) + _, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration) +} + +func (s *Server) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { + if r.Host == "api.hishtory.dev" || s.isProductionEnvironment { + panic("refusing to wipe the DB for prod") + } + if !s.isTestEnvironment { + panic("refusing to wipe the DB non-test environment") + } + + err := s.db.EncHistoryClear(r.Context()) + checkGormError(err) + + w.Header().Set("Content-Length", "0") + w.WriteHeader(http.StatusOK) +} + +func (s *Server) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) { + stats, err := s.db.Stats() + if err != nil { + panic(err) + } + + _, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections) +} + +func (s *Server) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error { + var usageData []shared.UsageData + usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId) + if err != nil { + return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err) + } + if len(usageData) == 0 { + err := s.db.UsageDataCreate( + ctx, + &shared.UsageData{ + UserId: userId, + DeviceId: deviceId, + LastUsed: time.Now(), + NumEntriesHandled: numEntriesHandled, + Version: version, + }, + ) + if err != nil { + return fmt.Errorf("db.UsageDataCreate: %w", err) + } + } else { + usage := usageData[0] + + if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil { + return fmt.Errorf("db.UsageDataUpdate: %w", err) + } + if numEntriesHandled > 0 { + if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err) + } + } + if usage.Version != version { + if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil { + return fmt.Errorf("db.UsageDataUpdateVersion: %w", err) + } + } + } + if isQuery { + if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil { + return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err) + } + } + + return nil +} diff --git a/internal/server/util.go b/internal/server/util.go new file mode 100644 index 0000000..7c84e8c --- /dev/null +++ b/internal/server/util.go @@ -0,0 +1,93 @@ +package server + +import ( + "fmt" + httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/DataDog/dd-trace-go.v1/profiler" + "math" + "net/http" + pprofhttp "net/http/pprof" + "os" + "runtime" + "strconv" +) + +func getMaximumNumberOfAllowedUsers() int { + maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS") + if maxNumUsersStr == "" { + return math.MaxInt + } + maxNumUsers, err := strconv.Atoi(maxNumUsersStr) + if err != nil { + return math.MaxInt + } + return maxNumUsers +} + +func configureObservability(mux *httptrace.ServeMux, releaseVersion string) func() { + // Profiler + err := profiler.Start( + profiler.WithService("hishtory-api"), + profiler.WithVersion(releaseVersion), + profiler.WithAPIKey(os.Getenv("DD_API_KEY")), + profiler.WithUDS("/var/run/datadog/apm.socket"), + profiler.WithProfileTypes( + profiler.CPUProfile, + profiler.HeapProfile, + ), + ) + if err != nil { + fmt.Printf("Failed to start DataDog profiler: %v\n", err) + } + // Tracer + tracer.Start( + tracer.WithRuntimeMetrics(), + tracer.WithService("hishtory-api"), + tracer.WithUDS("/var/run/datadog/apm.socket"), + ) + // TODO: should this be here? + defer tracer.Stop() + + // Pprof + mux.HandleFunc("/debug/pprof/", pprofhttp.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace) + + // Func to stop all of the above + return func() { + profiler.Stop() + tracer.Stop() + } +} + +func getHishtoryVersion(r *http.Request) string { + return r.Header.Get("X-Hishtory-Version") +} + +func getRemoteAddr(r *http.Request) string { + addr, ok := r.Header["X-Real-Ip"] + if !ok || len(addr) == 0 { + return "UnknownIp" + } + return addr[0] +} + +func getRequiredQueryParam(r *http.Request, queryParam string) string { + val := r.URL.Query().Get(queryParam) + if val == "" { + panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam)) + } + return val +} + +func checkGormError(err error) { + if err == nil { + return + } + + _, filename, line, _ := runtime.Caller(1) + panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err)) +}