finishing removing direct DB instructions from http handlers

This commit is contained in:
Sergio Moura 2023-09-08 10:57:44 -04:00
parent 0d6aa081d8
commit 3c18f62d99
5 changed files with 374 additions and 207 deletions

View File

@ -166,26 +166,19 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
if len(entries) == 0 {
return
}
updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false)
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", entries[0].UserId)
var devices []*shared.Device
checkGormResult(tx.Find(&devices))
if err := updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil {
fmt.Printf("updateUsageData: %v\n", err)
}
devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId)
checkGormError(err, 0)
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
}
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
for _, device := range devices {
for _, entry := range entries {
entry.DeviceId = device.DeviceId
}
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
for _, entriesChunk := range shared.Chunks(entries, 1000) {
checkGormResult(tx.Create(&entriesChunk))
}
}
return nil
})
err = GLOBAL_DB.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000)
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
}
@ -201,15 +194,12 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
updateUsageData(r, userId, deviceId, 0, false)
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", userId)
var historyEntries []*shared.EncHistoryEntry
checkGormResult(tx.Find(&historyEntries))
historyEntries, err := GLOBAL_DB.EncHistoryEntriesForUser(r.Context(), userId)
checkGormError(err, 1)
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
resp, err := json.Marshal(historyEntries)
if err != nil {
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
panic(err)
}
w.Write(resp)
}
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
@ -219,36 +209,31 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
updateUsageData(r, userId, deviceId, 0, true)
// Delete any entries that match a pending deletion request
var deletionRequests []*shared.DeletionRequest
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
for _, request := range deletionRequests {
_, err := applyDeletionRequestsToBackend(r.Context(), *request)
if err != nil {
panic(err)
}
_, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request)
checkGormError(err, 0)
}
// Then retrieve
tx := GLOBAL_DB.WithContext(r.Context()).Where("device_id = ? AND read_count < 5", deviceId)
var historyEntries []*shared.EncHistoryEntry
checkGormResult(tx.Find(&historyEntries))
historyEntries, err := GLOBAL_DB.EncHistoryEntriesForDevice(r.Context(), deviceId, 5)
checkGormError(err, 0)
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
resp, err := json.Marshal(historyEntries)
if err != nil {
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
panic(err)
}
w.Write(resp)
// And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids
// blocking the entire response. This does have a potential race condition, but that is fine.
if isProductionEnvironment() {
go func() {
span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount")
err = incrementReadCounts(ctx, deviceId)
err := GLOBAL_DB.DeviceIncrementReadCounts(ctx, deviceId)
span.Finish(tracer.WithError(err))
}()
} else {
err = incrementReadCounts(ctx, deviceId)
err := GLOBAL_DB.DeviceIncrementReadCounts(ctx, deviceId)
if err != nil {
panic("failed to increment read counts")
}
@ -259,10 +244,6 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
}
}
func incrementReadCounts(ctx context.Context, deviceId string) error {
return GLOBAL_DB.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId).Error
}
func getRemoteAddr(r *http.Request) string {
addr, ok := r.Header["X-Real-Ip"]
if !ok || len(addr) == 0 {
@ -312,12 +293,12 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
respBody, err := json.Marshal(dumpRequests)
if err != nil {
dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
if err := json.NewEncoder(w).Encode(dumpRequests); err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
}
w.Write(respBody)
}
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
@ -328,26 +309,25 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
panic(err)
}
var entries []shared.EncHistoryEntry
var entries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &entries)
if err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
// sanity check
for _, entry := range entries {
entry.DeviceId = requestingDeviceId
if entry.UserId != userId {
return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)
panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId))
}
checkGormResult(tx.Create(&entry))
}
return nil
})
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add dumped DB: %w", err))
}
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
err = GLOBAL_DB.EncHistoryCreateMulti(r.Context(), entries...)
checkGormError(err, 0)
err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId)
checkGormError(err, 0)
updateUsageData(r, userId, srcDeviceId, len(entries), false)
w.Header().Set("Content-Length", "0")
@ -371,16 +351,15 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId)
checkGormError(err, 0)
// Return all the deletion requests
var deletionRequests []*shared.DeletionRequest
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
respBody, err := json.Marshal(deletionRequests)
if err != nil {
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
if err := json.NewEncoder(w).Encode(deletionRequests); err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
}
w.Write(respBody)
}
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
@ -389,32 +368,15 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
panic(err)
}
var request shared.DeletionRequest
err = json.Unmarshal(data, &request)
if err != nil {
if err := json.Unmarshal(data, &request); err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
request.ReadCount = 0
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
// Store the deletion request so all the devices will get it
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", request.UserId)
var devices []*shared.Device
checkGormResult(tx.Find(&devices))
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId))
}
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
for _, device := range devices {
request.DestinationDeviceId = device.DeviceId
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&request))
}
// Also delete anything currently in the DB matching it
numDeleted, err := applyDeletionRequestsToBackend(r.Context(), request)
if err != nil {
panic(err)
}
fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request)
checkGormError(err, 0)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
@ -423,21 +385,27 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
if isProductionEnvironment() {
// Check that we have a reasonable looking set of devices/entries in the DB
rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
if err != nil {
panic(fmt.Sprintf("failed to count entries in DB: %v", err))
}
defer rows.Close()
if !rows.Next() {
//rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
//if err != nil {
// panic(fmt.Sprintf("failed to count entries in DB: %v", err))
//}
//defer rows.Close()
//if !rows.Next() {
// panic("Suspiciously few enc history entries!")
//}
encHistoryEntryCount, err := GLOBAL_DB.EncHistoryEntryCount(r.Context())
checkGormError(err, 0)
if encHistoryEntryCount < 1000 {
panic("Suspiciously few enc history entries!")
}
var count int64
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&count))
if count < 100 {
deviceCount, err := GLOBAL_DB.DevicesCount(r.Context())
checkGormError(err, 0)
if deviceCount < 100 {
panic("Suspiciously few devices!")
}
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.EncHistoryEntry{
err = GLOBAL_DB.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{
EncryptedData: []byte("data"),
Nonce: []byte("nonce"),
DeviceId: "healthcheck_device_id",
@ -445,7 +413,8 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
Date: time.Now(),
EncryptedId: "healthcheck_enc_id",
ReadCount: 10000,
}))
})
checkGormError(err, 0)
} else {
err := GLOBAL_DB.Ping()
if err != nil {
@ -455,16 +424,6 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
}
func applyDeletionRequestsToBackend(ctx context.Context, request shared.DeletionRequest) (int, error) {
tx := GLOBAL_DB.WithContext(ctx).Where("false")
for _, message := range request.Messages.Ids {
tx = tx.Or(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date))
}
result := tx.Delete(&shared.EncHistoryEntry{})
checkGormResult(result)
return int(result.RowsAffected), nil
}
func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
panic("refusing to wipe the DB for prod")
@ -472,7 +431,9 @@ func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
if !isTestEnvironment() {
panic("refusing to wipe the DB non-test environment")
}
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries"))
err := GLOBAL_DB.EncHistoryClear(r.Context())
checkGormError(err, 0)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
@ -562,7 +523,7 @@ func cron(ctx context.Context) error {
if err != nil {
panic(err)
}
err = cleanDatabase(ctx)
err = GLOBAL_DB.Clean(ctx)
if err != nil {
panic(err)
}
@ -669,21 +630,21 @@ func InitDB() {
var err error
GLOBAL_DB, err = OpenDB()
if err != nil {
panic(err)
}
sqlDb, err := GLOBAL_DB.DB.DB()
if err != nil {
panic(err)
panic(fmt.Errorf("OpenDB: %w", err))
}
if err := GLOBAL_DB.Ping(); err != nil {
panic(fmt.Errorf("ping: %w", err))
}
if isProductionEnvironment() {
sqlDb.SetMaxIdleConns(10)
if err := GLOBAL_DB.SetMaxIdleConns(10); err != nil {
panic(fmt.Errorf("failed to set max idle conns: %w", err))
}
}
if isTestEnvironment() {
sqlDb.SetMaxIdleConns(1)
if err := GLOBAL_DB.SetMaxIdleConns(1); err != nil {
panic(fmt.Errorf("failed to set max idle conns: %w", err))
}
}
}
@ -759,7 +720,8 @@ func feedbackHandler(w http.ResponseWriter, r *http.Request) {
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
}
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(feedback))
err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback)
checkGormError(err, 0)
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0)
@ -834,58 +796,6 @@ func byteCountToString(b int) string {
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
}
func cleanDatabase(ctx context.Context) error {
r := GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10")
if r.Error != nil {
return r.Error
}
r = GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM deletion_requests WHERE read_count > 100")
if r.Error != nil {
return r.Error
}
return nil
}
func deepCleanDatabase(ctx context.Context) {
err := GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
r := tx.Exec(`
CREATE TEMP TABLE temp_users_with_one_device AS (
SELECT user_id
FROM devices
GROUP BY user_id
HAVING COUNT(DISTINCT device_id) > 1
)
`)
if r.Error != nil {
return r.Error
}
r = tx.Exec(`
CREATE TEMP TABLE temp_inactive_users AS (
SELECT user_id
FROM usage_data
WHERE last_used <= (now() - INTERVAL '90 days')
)
`)
if r.Error != nil {
return r.Error
}
r = tx.Exec(`
SELECT COUNT(*) FROM enc_history_entries WHERE
date <= (now() - INTERVAL '90 days')
AND user_id IN (SELECT * FROM temp_users_with_one_device)
AND user_id IN (SELECT * FROM temp_inactive_users)
`)
if r.Error != nil {
return r.Error
}
fmt.Printf("Ran deep clean and deleted %d rows\n", r.RowsAffected)
return nil
})
if err != nil {
panic(fmt.Errorf("failed to deep clean DB: %w", err))
}
}
func configureObservability(mux *httptrace.ServeMux) func() {
// Profiler
err := profiler.Start(
@ -933,7 +843,11 @@ func main() {
if isProductionEnvironment() {
defer configureObservability(mux)()
go deepCleanDatabase(context.Background())
go func() {
if err := GLOBAL_DB.DeepClean(context.Background()); err != nil {
panic(err)
}
}()
}
mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler))

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"github.com/ddworken/hishtory/internal/database"
"io"
"net/http"
"net/http/httptest"
@ -17,7 +18,6 @@ import (
"github.com/ddworken/hishtory/shared/testutils"
"github.com/go-test/deep"
"github.com/google/uuid"
"gorm.io/gorm"
)
func TestESubmitThenQuery(t *testing.T) {
@ -564,15 +564,15 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
apiSubmitHandler(httptest.NewRecorder(), submitReq)
// Call cleanDatabase and just check that there are no panics
testutils.Check(t, cleanDatabase(context.TODO()))
testutils.Check(t, GLOBAL_DB.Clean(context.TODO()))
}
func assertNoLeakedConnections(t *testing.T, db *gorm.DB) {
sqlDB, err := db.DB()
func assertNoLeakedConnections(t *testing.T, db *database.DB) {
stats, err := db.Stats()
if err != nil {
t.Fatal(err)
}
numConns := sqlDB.Stats().OpenConnections
numConns := stats.OpenConnections
if numConns > 1 {
t.Fatalf("expected DB to have not leak connections, actually have %d", numConns)
}

View File

@ -86,6 +86,17 @@ func (db *DB) Ping() error {
return nil
}
func (db *DB) SetMaxIdleConns(n int) error {
rawDB, err := db.DB.DB()
if err != nil {
return err
}
rawDB.SetMaxIdleConns(n)
return nil
}
func (db *DB) Stats() (sql.DBStats, error) {
rawDB, err := db.DB.DB()
if err != nil {
@ -106,35 +117,6 @@ func (db *DB) DistinctUsers(ctx context.Context) (int64, error) {
return numDistinctUsers, nil
}
func (db *DB) DevicesCountForUser(ctx context.Context, userID string) (int64, error) {
var existingDevicesCount int64
tx := db.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
}
return existingDevicesCount, nil
}
func (db *DB) DevicesCount(ctx context.Context) (int64, error) {
var numDevices int64 = 0
tx := db.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
}
return numDevices, nil
}
func (db *DB) DeviceCreate(ctx context.Context, device *shared.Device) error {
tx := db.WithContext(ctx).Create(device)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}
func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) error {
tx := db.WithContext(ctx).Create(req)
if tx.Error != nil {
@ -144,12 +126,144 @@ func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) er
return nil
}
func (db *DB) EncHistoryEntryCount(ctx context.Context) (int64, error) {
var numDbEntries int64
tx := db.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries)
func (db *DB) DumpRequestForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DumpRequest, error) {
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
tx := db.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userID, deviceID).Find(&dumpRequests)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
}
return dumpRequests, nil
}
func (db *DB) DumpRequestDeleteForUserAndDevice(ctx context.Context, userID, deviceID string) error {
tx := db.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userID, deviceID)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}
func (db *DB) ApplyDeletionRequestsToBackend(ctx context.Context, request *shared.DeletionRequest) (int64, error) {
tx := db.WithContext(ctx).Where("false")
for _, message := range request.Messages.Ids {
tx = tx.Or(db.WithContext(ctx).Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date))
}
result := tx.Delete(&shared.EncHistoryEntry{})
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
}
return numDbEntries, nil
return result.RowsAffected, nil
}
func (db *DB) DeletionRequestInc(ctx context.Context, userID, deviceID string) error {
tx := db.WithContext(ctx).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE user_id = ? AND destination_device_id = ?", userID, deviceID)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}
func (db *DB) DeletionRequestsForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DeletionRequest, error) {
var deletionRequests []*shared.DeletionRequest
tx := db.WithContext(ctx).Where("user_id = ? AND destination_device_id = ?", userID, deviceID).Find(&deletionRequests)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
}
return deletionRequests, nil
}
func (db *DB) DeletionRequestCreate(ctx context.Context, request *shared.DeletionRequest) error {
userID := request.UserId
devices, err := db.DevicesForUser(ctx, userID)
if err != nil {
return fmt.Errorf("db.DevicesForUser: %w", err)
}
if len(devices) == 0 {
return fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", userID)
}
fmt.Printf("db.DeletionRequestCreate: Found %d devices\n", len(devices))
// TODO: maybe this should be a transaction?
for _, device := range devices {
request.DestinationDeviceId = device.DeviceId
tx := db.WithContext(ctx).Create(&request)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
}
numDeleted, err := db.ApplyDeletionRequestsToBackend(ctx, request)
if err != nil {
return fmt.Errorf("db.ApplyDeletionRequestsToBackend: %w", err)
}
fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
return nil
}
func (db *DB) FeedbackCreate(ctx context.Context, feedback *shared.Feedback) error {
tx := db.WithContext(ctx).Create(feedback)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}
func (db *DB) Clean(ctx context.Context) error {
r := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10")
if r.Error != nil {
return r.Error
}
r = db.WithContext(ctx).Exec("DELETE FROM deletion_requests WHERE read_count > 100")
if r.Error != nil {
return r.Error
}
return nil
}
func (db *DB) DeepClean(ctx context.Context) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
r := tx.Exec(`
CREATE TEMP TABLE temp_users_with_one_device AS (
SELECT user_id
FROM devices
GROUP BY user_id
HAVING COUNT(DISTINCT device_id) > 1
)
`)
if r.Error != nil {
return r.Error
}
r = tx.Exec(`
CREATE TEMP TABLE temp_inactive_users AS (
SELECT user_id
FROM usage_data
WHERE last_used <= (now() - INTERVAL '90 days')
)
`)
if r.Error != nil {
return r.Error
}
r = tx.Exec(`
SELECT COUNT(*) FROM enc_history_entries WHERE
date <= (now() - INTERVAL '90 days')
AND user_id IN (SELECT * FROM temp_users_with_one_device)
AND user_id IN (SELECT * FROM temp_inactive_users)
`)
if r.Error != nil {
return r.Error
}
fmt.Printf("Ran deep clean and deleted %d rows\n", r.RowsAffected)
return nil
})
}

View File

@ -0,0 +1,69 @@
package database
import (
"context"
"fmt"
"github.com/ddworken/hishtory/shared"
"gorm.io/gorm"
)
func (db *DB) DevicesCountForUser(ctx context.Context, userID string) (int64, error) {
var existingDevicesCount int64
tx := db.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
}
return existingDevicesCount, nil
}
func (db *DB) DevicesCount(ctx context.Context) (int64, error) {
var numDevices int64 = 0
tx := db.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
}
return numDevices, nil
}
func (db *DB) DeviceCreate(ctx context.Context, device *shared.Device) error {
tx := db.WithContext(ctx).Create(device)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}
func (db *DB) DeviceEntriesCreateChunk(ctx context.Context, devices []*shared.Device, entries []*shared.EncHistoryEntry, chunkSize int) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, device := range devices {
for _, entry := range entries {
entry.DeviceId = device.DeviceId
}
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
for _, entriesChunk := range shared.Chunks(entries, chunkSize) {
resp := tx.Create(&entriesChunk)
if resp.Error != nil {
return fmt.Errorf("resp.Error: %w", resp.Error)
}
}
}
return nil
})
}
func (db *DB) DevicesForUser(ctx context.Context, userID string) ([]*shared.Device, error) {
var devices []*shared.Device
tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&devices)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
}
return devices, nil
}
func (db *DB) DeviceIncrementReadCounts(ctx context.Context, deviceID string) error {
return db.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceID).Error
}

View File

@ -0,0 +1,70 @@
package database
import (
"context"
"fmt"
"github.com/ddworken/hishtory/shared"
"gorm.io/gorm"
)
func (db *DB) EncHistoryEntryCount(ctx context.Context) (int64, error) {
var numDbEntries int64
tx := db.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries)
if tx.Error != nil {
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
}
return numDbEntries, nil
}
func (db *DB) EncHistoryEntriesForUser(ctx context.Context, userID string) ([]*shared.EncHistoryEntry, error) {
var historyEntries []*shared.EncHistoryEntry
tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&historyEntries)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
}
return historyEntries, nil
}
func (db *DB) EncHistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
var historyEntries []*shared.EncHistoryEntry
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ?", deviceID, limit).Find(&historyEntries)
if tx.Error != nil {
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
}
return historyEntries, nil
}
func (db *DB) EncHistoryCreate(ctx context.Context, entry *shared.EncHistoryEntry) error {
tx := db.WithContext(ctx).Create(entry)
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}
func (db *DB) EncHistoryCreateMulti(ctx context.Context, entries ...*shared.EncHistoryEntry) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, entry := range entries {
resp := tx.Create(&entry)
if resp.Error != nil {
return fmt.Errorf("resp.Error: %w", resp.Error)
}
}
return nil
})
}
func (db *DB) EncHistoryClear(ctx context.Context) error {
tx := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries")
if tx.Error != nil {
return fmt.Errorf("tx.Error: %w", tx.Error)
}
return nil
}