diff --git a/backend/server/server.go b/backend/server/server.go index 45e1509..19aaa6f 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -166,26 +166,19 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) { if len(entries) == 0 { return } - updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false) - tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", entries[0].UserId) - var devices []*shared.Device - checkGormResult(tx.Find(&devices)) + if err := updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil { + fmt.Printf("updateUsageData: %v\n", err) + } + + devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId) + checkGormError(err, 0) + if len(devices) == 0 { panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId)) } fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices)) - err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error { - for _, device := range devices { - for _, entry := range entries { - entry.DeviceId = device.DeviceId - } - // Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error - for _, entriesChunk := range shared.Chunks(entries, 1000) { - checkGormResult(tx.Create(&entriesChunk)) - } - } - return nil - }) + + err = GLOBAL_DB.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000) if err != nil { panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err)) } @@ -201,15 +194,12 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") updateUsageData(r, userId, deviceId, 0, false) - tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", userId) - var historyEntries []*shared.EncHistoryEntry - checkGormResult(tx.Find(&historyEntries)) + historyEntries, err := GLOBAL_DB.EncHistoryEntriesForUser(r.Context(), userId) + checkGormError(err, 1) fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries)) - resp, err := json.Marshal(historyEntries) - if err != nil { + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { panic(err) } - w.Write(resp) } func apiQueryHandler(w http.ResponseWriter, r *http.Request) { @@ -219,36 +209,31 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) { updateUsageData(r, userId, deviceId, 0, true) // Delete any entries that match a pending deletion request - var deletionRequests []*shared.DeletionRequest - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests)) + deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) for _, request := range deletionRequests { - _, err := applyDeletionRequestsToBackend(r.Context(), *request) - if err != nil { - panic(err) - } + _, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request) + checkGormError(err, 0) } // Then retrieve - tx := GLOBAL_DB.WithContext(r.Context()).Where("device_id = ? AND read_count < 5", deviceId) - var historyEntries []*shared.EncHistoryEntry - checkGormResult(tx.Find(&historyEntries)) + historyEntries, err := GLOBAL_DB.EncHistoryEntriesForDevice(r.Context(), deviceId, 5) + checkGormError(err, 0) fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL) - resp, err := json.Marshal(historyEntries) - if err != nil { + if err := json.NewEncoder(w).Encode(historyEntries); err != nil { panic(err) } - w.Write(resp) // And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids // blocking the entire response. This does have a potential race condition, but that is fine. if isProductionEnvironment() { go func() { span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount") - err = incrementReadCounts(ctx, deviceId) + err := GLOBAL_DB.DeviceIncrementReadCounts(ctx, deviceId) span.Finish(tracer.WithError(err)) }() } else { - err = incrementReadCounts(ctx, deviceId) + err := GLOBAL_DB.DeviceIncrementReadCounts(ctx, deviceId) if err != nil { panic("failed to increment read counts") } @@ -259,10 +244,6 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) { } } -func incrementReadCounts(ctx context.Context, deviceId string) error { - return GLOBAL_DB.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId).Error -} - func getRemoteAddr(r *http.Request) string { addr, ok := r.Header["X-Real-Ip"] if !ok || len(addr) == 0 { @@ -312,12 +293,12 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) { deviceId := getRequiredQueryParam(r, "device_id") var dumpRequests []*shared.DumpRequest // Filter out ones requested by the hishtory instance that sent this request - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests)) - respBody, err := json.Marshal(dumpRequests) - if err != nil { + dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + + if err := json.NewEncoder(w).Encode(dumpRequests); err != nil { panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) } - w.Write(respBody) } func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { @@ -328,26 +309,25 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) { if err != nil { panic(err) } - var entries []shared.EncHistoryEntry + var entries []*shared.EncHistoryEntry err = json.Unmarshal(data, &entries) if err != nil { panic(fmt.Sprintf("body=%#v, err=%v", data, err)) } fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries)) - err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error { - for _, entry := range entries { - entry.DeviceId = requestingDeviceId - if entry.UserId != userId { - return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId) - } - checkGormResult(tx.Create(&entry)) + + // sanity check + for _, entry := range entries { + entry.DeviceId = requestingDeviceId + if entry.UserId != userId { + panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)) } - return nil - }) - if err != nil { - panic(fmt.Errorf("failed to execute transaction to add dumped DB: %w", err)) } - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId)) + + err = GLOBAL_DB.EncHistoryCreateMulti(r.Context(), entries...) + checkGormError(err, 0) + err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId) + checkGormError(err, 0) updateUsageData(r, userId, srcDeviceId, len(entries), false) w.Header().Set("Content-Length", "0") @@ -371,16 +351,15 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { deviceId := getRequiredQueryParam(r, "device_id") // Increment the ReadCount - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId)) + err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId) + checkGormError(err, 0) // Return all the deletion requests - var deletionRequests []*shared.DeletionRequest - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests)) - respBody, err := json.Marshal(deletionRequests) - if err != nil { + deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId) + checkGormError(err, 0) + if err := json.NewEncoder(w).Encode(deletionRequests); err != nil { panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err)) } - w.Write(respBody) } func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { @@ -389,32 +368,15 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { panic(err) } var request shared.DeletionRequest - err = json.Unmarshal(data, &request) - if err != nil { + + if err := json.Unmarshal(data, &request); err != nil { panic(fmt.Sprintf("body=%#v, err=%v", data, err)) } request.ReadCount = 0 fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids)) - // Store the deletion request so all the devices will get it - tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", request.UserId) - var devices []*shared.Device - checkGormResult(tx.Find(&devices)) - if len(devices) == 0 { - panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId)) - } - fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices)) - for _, device := range devices { - request.DestinationDeviceId = device.DeviceId - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&request)) - } - - // Also delete anything currently in the DB matching it - numDeleted, err := applyDeletionRequestsToBackend(r.Context(), request) - if err != nil { - panic(err) - } - fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted) + err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request) + checkGormError(err, 0) w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusOK) @@ -423,21 +385,27 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) { func healthCheckHandler(w http.ResponseWriter, r *http.Request) { if isProductionEnvironment() { // Check that we have a reasonable looking set of devices/entries in the DB - rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() - if err != nil { - panic(fmt.Sprintf("failed to count entries in DB: %v", err)) - } - defer rows.Close() - if !rows.Next() { + //rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows() + //if err != nil { + // panic(fmt.Sprintf("failed to count entries in DB: %v", err)) + //} + //defer rows.Close() + //if !rows.Next() { + // panic("Suspiciously few enc history entries!") + //} + encHistoryEntryCount, err := GLOBAL_DB.EncHistoryEntryCount(r.Context()) + checkGormError(err, 0) + if encHistoryEntryCount < 1000 { panic("Suspiciously few enc history entries!") } - var count int64 - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&count)) - if count < 100 { + + deviceCount, err := GLOBAL_DB.DevicesCount(r.Context()) + checkGormError(err, 0) + if deviceCount < 100 { panic("Suspiciously few devices!") } // Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron. - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.EncHistoryEntry{ + err = GLOBAL_DB.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{ EncryptedData: []byte("data"), Nonce: []byte("nonce"), DeviceId: "healthcheck_device_id", @@ -445,7 +413,8 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) { Date: time.Now(), EncryptedId: "healthcheck_enc_id", ReadCount: 10000, - })) + }) + checkGormError(err, 0) } else { err := GLOBAL_DB.Ping() if err != nil { @@ -455,16 +424,6 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) } -func applyDeletionRequestsToBackend(ctx context.Context, request shared.DeletionRequest) (int, error) { - tx := GLOBAL_DB.WithContext(ctx).Where("false") - for _, message := range request.Messages.Ids { - tx = tx.Or(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date)) - } - result := tx.Delete(&shared.EncHistoryEntry{}) - checkGormResult(result) - return int(result.RowsAffected), nil -} - func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { if r.Host == "api.hishtory.dev" || isProductionEnvironment() { panic("refusing to wipe the DB for prod") @@ -472,7 +431,9 @@ func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) { if !isTestEnvironment() { panic("refusing to wipe the DB non-test environment") } - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries")) + + err := GLOBAL_DB.EncHistoryClear(r.Context()) + checkGormError(err, 0) w.Header().Set("Content-Length", "0") w.WriteHeader(http.StatusOK) @@ -562,7 +523,7 @@ func cron(ctx context.Context) error { if err != nil { panic(err) } - err = cleanDatabase(ctx) + err = GLOBAL_DB.Clean(ctx) if err != nil { panic(err) } @@ -669,21 +630,21 @@ func InitDB() { var err error GLOBAL_DB, err = OpenDB() if err != nil { - panic(err) - } - sqlDb, err := GLOBAL_DB.DB.DB() - if err != nil { - panic(err) + panic(fmt.Errorf("OpenDB: %w", err)) } if err := GLOBAL_DB.Ping(); err != nil { panic(fmt.Errorf("ping: %w", err)) } if isProductionEnvironment() { - sqlDb.SetMaxIdleConns(10) + if err := GLOBAL_DB.SetMaxIdleConns(10); err != nil { + panic(fmt.Errorf("failed to set max idle conns: %w", err)) + } } if isTestEnvironment() { - sqlDb.SetMaxIdleConns(1) + if err := GLOBAL_DB.SetMaxIdleConns(1); err != nil { + panic(fmt.Errorf("failed to set max idle conns: %w", err)) + } } } @@ -759,7 +720,8 @@ func feedbackHandler(w http.ResponseWriter, r *http.Request) { panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err)) } fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback) - checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(feedback)) + err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback) + checkGormError(err, 0) if GLOBAL_STATSD != nil { GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0) @@ -834,58 +796,6 @@ func byteCountToString(b int) string { return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp]) } -func cleanDatabase(ctx context.Context) error { - r := GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10") - if r.Error != nil { - return r.Error - } - r = GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM deletion_requests WHERE read_count > 100") - if r.Error != nil { - return r.Error - } - return nil -} - -func deepCleanDatabase(ctx context.Context) { - err := GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - r := tx.Exec(` - CREATE TEMP TABLE temp_users_with_one_device AS ( - SELECT user_id - FROM devices - GROUP BY user_id - HAVING COUNT(DISTINCT device_id) > 1 - ) - `) - if r.Error != nil { - return r.Error - } - r = tx.Exec(` - CREATE TEMP TABLE temp_inactive_users AS ( - SELECT user_id - FROM usage_data - WHERE last_used <= (now() - INTERVAL '90 days') - ) - `) - if r.Error != nil { - return r.Error - } - r = tx.Exec(` - SELECT COUNT(*) FROM enc_history_entries WHERE - date <= (now() - INTERVAL '90 days') - AND user_id IN (SELECT * FROM temp_users_with_one_device) - AND user_id IN (SELECT * FROM temp_inactive_users) - `) - if r.Error != nil { - return r.Error - } - fmt.Printf("Ran deep clean and deleted %d rows\n", r.RowsAffected) - return nil - }) - if err != nil { - panic(fmt.Errorf("failed to deep clean DB: %w", err)) - } -} - func configureObservability(mux *httptrace.ServeMux) func() { // Profiler err := profiler.Start( @@ -933,7 +843,11 @@ func main() { if isProductionEnvironment() { defer configureObservability(mux)() - go deepCleanDatabase(context.Background()) + go func() { + if err := GLOBAL_DB.DeepClean(context.Background()); err != nil { + panic(err) + } + }() } mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler)) diff --git a/backend/server/server_test.go b/backend/server/server_test.go index 8efe178..bd1d93f 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "github.com/ddworken/hishtory/internal/database" "io" "net/http" "net/http/httptest" @@ -17,7 +18,6 @@ import ( "github.com/ddworken/hishtory/shared/testutils" "github.com/go-test/deep" "github.com/google/uuid" - "gorm.io/gorm" ) func TestESubmitThenQuery(t *testing.T) { @@ -564,15 +564,15 @@ func TestCleanDatabaseNoErrors(t *testing.T) { apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics - testutils.Check(t, cleanDatabase(context.TODO())) + testutils.Check(t, GLOBAL_DB.Clean(context.TODO())) } -func assertNoLeakedConnections(t *testing.T, db *gorm.DB) { - sqlDB, err := db.DB() +func assertNoLeakedConnections(t *testing.T, db *database.DB) { + stats, err := db.Stats() if err != nil { t.Fatal(err) } - numConns := sqlDB.Stats().OpenConnections + numConns := stats.OpenConnections if numConns > 1 { t.Fatalf("expected DB to have not leak connections, actually have %d", numConns) } diff --git a/internal/database/db.go b/internal/database/db.go index e727dec..93c03b3 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -86,6 +86,17 @@ func (db *DB) Ping() error { return nil } +func (db *DB) SetMaxIdleConns(n int) error { + rawDB, err := db.DB.DB() + if err != nil { + return err + } + + rawDB.SetMaxIdleConns(n) + + return nil +} + func (db *DB) Stats() (sql.DBStats, error) { rawDB, err := db.DB.DB() if err != nil { @@ -106,35 +117,6 @@ func (db *DB) DistinctUsers(ctx context.Context) (int64, error) { return numDistinctUsers, nil } -func (db *DB) DevicesCountForUser(ctx context.Context, userID string) (int64, error) { - var existingDevicesCount int64 - tx := db.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount) - if tx.Error != nil { - return 0, fmt.Errorf("tx.Error: %w", tx.Error) - } - - return existingDevicesCount, nil -} - -func (db *DB) DevicesCount(ctx context.Context) (int64, error) { - var numDevices int64 = 0 - tx := db.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices) - if tx.Error != nil { - return 0, fmt.Errorf("tx.Error: %w", tx.Error) - } - - return numDevices, nil -} - -func (db *DB) DeviceCreate(ctx context.Context, device *shared.Device) error { - tx := db.WithContext(ctx).Create(device) - if tx.Error != nil { - return fmt.Errorf("tx.Error: %w", tx.Error) - } - - return nil -} - func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) error { tx := db.WithContext(ctx).Create(req) if tx.Error != nil { @@ -144,12 +126,144 @@ func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) er return nil } -func (db *DB) EncHistoryEntryCount(ctx context.Context) (int64, error) { - var numDbEntries int64 - tx := db.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries) +func (db *DB) DumpRequestForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DumpRequest, error) { + var dumpRequests []*shared.DumpRequest + // Filter out ones requested by the hishtory instance that sent this request + tx := db.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userID, deviceID).Find(&dumpRequests) + if tx.Error != nil { + return nil, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return dumpRequests, nil +} + +func (db *DB) DumpRequestDeleteForUserAndDevice(ctx context.Context, userID, deviceID string) error { + tx := db.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userID, deviceID) + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + + return nil +} + +func (db *DB) ApplyDeletionRequestsToBackend(ctx context.Context, request *shared.DeletionRequest) (int64, error) { + tx := db.WithContext(ctx).Where("false") + for _, message := range request.Messages.Ids { + tx = tx.Or(db.WithContext(ctx).Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date)) + } + result := tx.Delete(&shared.EncHistoryEntry{}) if tx.Error != nil { return 0, fmt.Errorf("tx.Error: %w", tx.Error) } - - return numDbEntries, nil + return result.RowsAffected, nil +} + +func (db *DB) DeletionRequestInc(ctx context.Context, userID, deviceID string) error { + tx := db.WithContext(ctx).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE user_id = ? AND destination_device_id = ?", userID, deviceID) + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + + return nil +} + +func (db *DB) DeletionRequestsForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DeletionRequest, error) { + var deletionRequests []*shared.DeletionRequest + tx := db.WithContext(ctx).Where("user_id = ? AND destination_device_id = ?", userID, deviceID).Find(&deletionRequests) + if tx.Error != nil { + return nil, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return deletionRequests, nil +} + +func (db *DB) DeletionRequestCreate(ctx context.Context, request *shared.DeletionRequest) error { + userID := request.UserId + + devices, err := db.DevicesForUser(ctx, userID) + if err != nil { + return fmt.Errorf("db.DevicesForUser: %w", err) + } + + if len(devices) == 0 { + return fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", userID) + } + + fmt.Printf("db.DeletionRequestCreate: Found %d devices\n", len(devices)) + + // TODO: maybe this should be a transaction? + for _, device := range devices { + request.DestinationDeviceId = device.DeviceId + tx := db.WithContext(ctx).Create(&request) + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + } + + numDeleted, err := db.ApplyDeletionRequestsToBackend(ctx, request) + if err != nil { + return fmt.Errorf("db.ApplyDeletionRequestsToBackend: %w", err) + } + fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted) + + return nil +} + +func (db *DB) FeedbackCreate(ctx context.Context, feedback *shared.Feedback) error { + tx := db.WithContext(ctx).Create(feedback) + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + + return nil +} + +func (db *DB) Clean(ctx context.Context) error { + r := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10") + if r.Error != nil { + return r.Error + } + r = db.WithContext(ctx).Exec("DELETE FROM deletion_requests WHERE read_count > 100") + if r.Error != nil { + return r.Error + } + + return nil +} + +func (db *DB) DeepClean(ctx context.Context) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + r := tx.Exec(` + CREATE TEMP TABLE temp_users_with_one_device AS ( + SELECT user_id + FROM devices + GROUP BY user_id + HAVING COUNT(DISTINCT device_id) > 1 + ) + `) + if r.Error != nil { + return r.Error + } + r = tx.Exec(` + CREATE TEMP TABLE temp_inactive_users AS ( + SELECT user_id + FROM usage_data + WHERE last_used <= (now() - INTERVAL '90 days') + ) + `) + if r.Error != nil { + return r.Error + } + r = tx.Exec(` + SELECT COUNT(*) FROM enc_history_entries WHERE + date <= (now() - INTERVAL '90 days') + AND user_id IN (SELECT * FROM temp_users_with_one_device) + AND user_id IN (SELECT * FROM temp_inactive_users) + `) + if r.Error != nil { + return r.Error + } + fmt.Printf("Ran deep clean and deleted %d rows\n", r.RowsAffected) + return nil + }) } diff --git a/internal/database/device.go b/internal/database/device.go new file mode 100644 index 0000000..554e0e9 --- /dev/null +++ b/internal/database/device.go @@ -0,0 +1,69 @@ +package database + +import ( + "context" + "fmt" + "github.com/ddworken/hishtory/shared" + "gorm.io/gorm" +) + +func (db *DB) DevicesCountForUser(ctx context.Context, userID string) (int64, error) { + var existingDevicesCount int64 + tx := db.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount) + if tx.Error != nil { + return 0, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return existingDevicesCount, nil +} + +func (db *DB) DevicesCount(ctx context.Context) (int64, error) { + var numDevices int64 = 0 + tx := db.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices) + if tx.Error != nil { + return 0, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return numDevices, nil +} + +func (db *DB) DeviceCreate(ctx context.Context, device *shared.Device) error { + tx := db.WithContext(ctx).Create(device) + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + + return nil +} + +func (db *DB) DeviceEntriesCreateChunk(ctx context.Context, devices []*shared.Device, entries []*shared.EncHistoryEntry, chunkSize int) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + for _, device := range devices { + for _, entry := range entries { + entry.DeviceId = device.DeviceId + } + // Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error + for _, entriesChunk := range shared.Chunks(entries, chunkSize) { + resp := tx.Create(&entriesChunk) + if resp.Error != nil { + return fmt.Errorf("resp.Error: %w", resp.Error) + } + } + } + return nil + }) +} + +func (db *DB) DevicesForUser(ctx context.Context, userID string) ([]*shared.Device, error) { + var devices []*shared.Device + tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&devices) + if tx.Error != nil { + return nil, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return devices, nil +} + +func (db *DB) DeviceIncrementReadCounts(ctx context.Context, deviceID string) error { + return db.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceID).Error +} diff --git a/internal/database/enchistory.go b/internal/database/enchistory.go new file mode 100644 index 0000000..bc15f6b --- /dev/null +++ b/internal/database/enchistory.go @@ -0,0 +1,70 @@ +package database + +import ( + "context" + "fmt" + "github.com/ddworken/hishtory/shared" + "gorm.io/gorm" +) + +func (db *DB) EncHistoryEntryCount(ctx context.Context) (int64, error) { + var numDbEntries int64 + tx := db.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries) + if tx.Error != nil { + return 0, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return numDbEntries, nil +} + +func (db *DB) EncHistoryEntriesForUser(ctx context.Context, userID string) ([]*shared.EncHistoryEntry, error) { + var historyEntries []*shared.EncHistoryEntry + tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&historyEntries) + + if tx.Error != nil { + return nil, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return historyEntries, nil +} + +func (db *DB) EncHistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) { + var historyEntries []*shared.EncHistoryEntry + tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ?", deviceID, limit).Find(&historyEntries) + + if tx.Error != nil { + return nil, fmt.Errorf("tx.Error: %w", tx.Error) + } + + return historyEntries, nil +} + +func (db *DB) EncHistoryCreate(ctx context.Context, entry *shared.EncHistoryEntry) error { + tx := db.WithContext(ctx).Create(entry) + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + + return nil +} + +func (db *DB) EncHistoryCreateMulti(ctx context.Context, entries ...*shared.EncHistoryEntry) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + for _, entry := range entries { + resp := tx.Create(&entry) + if resp.Error != nil { + return fmt.Errorf("resp.Error: %w", resp.Error) + } + } + return nil + }) +} + +func (db *DB) EncHistoryClear(ctx context.Context) error { + tx := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries") + if tx.Error != nil { + return fmt.Errorf("tx.Error: %w", tx.Error) + } + + return nil +}