mirror of
https://github.com/ddworken/hishtory.git
synced 2025-06-20 03:47:54 +02:00
finishing removing direct DB instructions from http handlers
This commit is contained in:
parent
0d6aa081d8
commit
3c18f62d99
@ -166,26 +166,19 @@ func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false)
|
||||
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", entries[0].UserId)
|
||||
var devices []*shared.Device
|
||||
checkGormResult(tx.Find(&devices))
|
||||
if err := updateUsageData(r, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil {
|
||||
fmt.Printf("updateUsageData: %v\n", err)
|
||||
}
|
||||
|
||||
devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId)
|
||||
checkGormError(err, 0)
|
||||
|
||||
if len(devices) == 0 {
|
||||
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
|
||||
}
|
||||
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
|
||||
err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
|
||||
for _, device := range devices {
|
||||
for _, entry := range entries {
|
||||
entry.DeviceId = device.DeviceId
|
||||
}
|
||||
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
|
||||
for _, entriesChunk := range shared.Chunks(entries, 1000) {
|
||||
checkGormResult(tx.Create(&entriesChunk))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
err = GLOBAL_DB.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
|
||||
}
|
||||
@ -201,15 +194,12 @@ func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userId := getRequiredQueryParam(r, "user_id")
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
updateUsageData(r, userId, deviceId, 0, false)
|
||||
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", userId)
|
||||
var historyEntries []*shared.EncHistoryEntry
|
||||
checkGormResult(tx.Find(&historyEntries))
|
||||
historyEntries, err := GLOBAL_DB.EncHistoryEntriesForUser(r.Context(), userId)
|
||||
checkGormError(err, 1)
|
||||
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
|
||||
resp, err := json.Marshal(historyEntries)
|
||||
if err != nil {
|
||||
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Write(resp)
|
||||
}
|
||||
|
||||
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@ -219,36 +209,31 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
updateUsageData(r, userId, deviceId, 0, true)
|
||||
|
||||
// Delete any entries that match a pending deletion request
|
||||
var deletionRequests []*shared.DeletionRequest
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("destination_device_id = ? AND user_id = ?", deviceId, userId).Find(&deletionRequests))
|
||||
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
||||
checkGormError(err, 0)
|
||||
for _, request := range deletionRequests {
|
||||
_, err := applyDeletionRequestsToBackend(r.Context(), *request)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request)
|
||||
checkGormError(err, 0)
|
||||
}
|
||||
|
||||
// Then retrieve
|
||||
tx := GLOBAL_DB.WithContext(r.Context()).Where("device_id = ? AND read_count < 5", deviceId)
|
||||
var historyEntries []*shared.EncHistoryEntry
|
||||
checkGormResult(tx.Find(&historyEntries))
|
||||
historyEntries, err := GLOBAL_DB.EncHistoryEntriesForDevice(r.Context(), deviceId, 5)
|
||||
checkGormError(err, 0)
|
||||
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
|
||||
resp, err := json.Marshal(historyEntries)
|
||||
if err != nil {
|
||||
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Write(resp)
|
||||
|
||||
// And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids
|
||||
// blocking the entire response. This does have a potential race condition, but that is fine.
|
||||
if isProductionEnvironment() {
|
||||
go func() {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount")
|
||||
err = incrementReadCounts(ctx, deviceId)
|
||||
err := GLOBAL_DB.DeviceIncrementReadCounts(ctx, deviceId)
|
||||
span.Finish(tracer.WithError(err))
|
||||
}()
|
||||
} else {
|
||||
err = incrementReadCounts(ctx, deviceId)
|
||||
err := GLOBAL_DB.DeviceIncrementReadCounts(ctx, deviceId)
|
||||
if err != nil {
|
||||
panic("failed to increment read counts")
|
||||
}
|
||||
@ -259,10 +244,6 @@ func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func incrementReadCounts(ctx context.Context, deviceId string) error {
|
||||
return GLOBAL_DB.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceId).Error
|
||||
}
|
||||
|
||||
func getRemoteAddr(r *http.Request) string {
|
||||
addr, ok := r.Header["X-Real-Ip"]
|
||||
if !ok || len(addr) == 0 {
|
||||
@ -312,12 +293,12 @@ func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
var dumpRequests []*shared.DumpRequest
|
||||
// Filter out ones requested by the hishtory instance that sent this request
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND requesting_device_id != ?", userId, deviceId).Find(&dumpRequests))
|
||||
respBody, err := json.Marshal(dumpRequests)
|
||||
if err != nil {
|
||||
dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
|
||||
checkGormError(err, 0)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(dumpRequests); err != nil {
|
||||
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
|
||||
}
|
||||
w.Write(respBody)
|
||||
}
|
||||
|
||||
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@ -328,26 +309,25 @@ func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var entries []shared.EncHistoryEntry
|
||||
var entries []*shared.EncHistoryEntry
|
||||
err = json.Unmarshal(data, &entries)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
||||
}
|
||||
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
|
||||
err = GLOBAL_DB.WithContext(r.Context()).Transaction(func(tx *gorm.DB) error {
|
||||
for _, entry := range entries {
|
||||
entry.DeviceId = requestingDeviceId
|
||||
if entry.UserId != userId {
|
||||
return fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId)
|
||||
}
|
||||
checkGormResult(tx.Create(&entry))
|
||||
|
||||
// sanity check
|
||||
for _, entry := range entries {
|
||||
entry.DeviceId = requestingDeviceId
|
||||
if entry.UserId != userId {
|
||||
panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to execute transaction to add dumped DB: %w", err))
|
||||
}
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userId, requestingDeviceId))
|
||||
|
||||
err = GLOBAL_DB.EncHistoryCreateMulti(r.Context(), entries...)
|
||||
checkGormError(err, 0)
|
||||
err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId)
|
||||
checkGormError(err, 0)
|
||||
updateUsageData(r, userId, srcDeviceId, len(entries), false)
|
||||
|
||||
w.Header().Set("Content-Length", "0")
|
||||
@ -371,16 +351,15 @@ func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
deviceId := getRequiredQueryParam(r, "device_id")
|
||||
|
||||
// Increment the ReadCount
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE destination_device_id = ? AND user_id = ?", deviceId, userId))
|
||||
err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId)
|
||||
checkGormError(err, 0)
|
||||
|
||||
// Return all the deletion requests
|
||||
var deletionRequests []*shared.DeletionRequest
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND destination_device_id = ?", userId, deviceId).Find(&deletionRequests))
|
||||
respBody, err := json.Marshal(deletionRequests)
|
||||
if err != nil {
|
||||
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
||||
checkGormError(err, 0)
|
||||
if err := json.NewEncoder(w).Encode(deletionRequests); err != nil {
|
||||
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
|
||||
}
|
||||
w.Write(respBody)
|
||||
}
|
||||
|
||||
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@ -389,32 +368,15 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
panic(err)
|
||||
}
|
||||
var request shared.DeletionRequest
|
||||
err = json.Unmarshal(data, &request)
|
||||
if err != nil {
|
||||
|
||||
if err := json.Unmarshal(data, &request); err != nil {
|
||||
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
||||
}
|
||||
request.ReadCount = 0
|
||||
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
|
||||
|
||||
// Store the deletion request so all the devices will get it
|
||||
tx := GLOBAL_DB.WithContext(r.Context()).Where("user_id = ?", request.UserId)
|
||||
var devices []*shared.Device
|
||||
checkGormResult(tx.Find(&devices))
|
||||
if len(devices) == 0 {
|
||||
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", request.UserId))
|
||||
}
|
||||
fmt.Printf("addDeletionRequestHandler: Found %d devices\n", len(devices))
|
||||
for _, device := range devices {
|
||||
request.DestinationDeviceId = device.DeviceId
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&request))
|
||||
}
|
||||
|
||||
// Also delete anything currently in the DB matching it
|
||||
numDeleted, err := applyDeletionRequestsToBackend(r.Context(), request)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
|
||||
err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request)
|
||||
checkGormError(err, 0)
|
||||
|
||||
w.Header().Set("Content-Length", "0")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@ -423,21 +385,27 @@ func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if isProductionEnvironment() {
|
||||
// Check that we have a reasonable looking set of devices/entries in the DB
|
||||
rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to count entries in DB: %v", err))
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
//rows, err := GLOBAL_DB.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
|
||||
//if err != nil {
|
||||
// panic(fmt.Sprintf("failed to count entries in DB: %v", err))
|
||||
//}
|
||||
//defer rows.Close()
|
||||
//if !rows.Next() {
|
||||
// panic("Suspiciously few enc history entries!")
|
||||
//}
|
||||
encHistoryEntryCount, err := GLOBAL_DB.EncHistoryEntryCount(r.Context())
|
||||
checkGormError(err, 0)
|
||||
if encHistoryEntryCount < 1000 {
|
||||
panic("Suspiciously few enc history entries!")
|
||||
}
|
||||
var count int64
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.Device{}).Count(&count))
|
||||
if count < 100 {
|
||||
|
||||
deviceCount, err := GLOBAL_DB.DevicesCount(r.Context())
|
||||
checkGormError(err, 0)
|
||||
if deviceCount < 100 {
|
||||
panic("Suspiciously few devices!")
|
||||
}
|
||||
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(&shared.EncHistoryEntry{
|
||||
err = GLOBAL_DB.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{
|
||||
EncryptedData: []byte("data"),
|
||||
Nonce: []byte("nonce"),
|
||||
DeviceId: "healthcheck_device_id",
|
||||
@ -445,7 +413,8 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Date: time.Now(),
|
||||
EncryptedId: "healthcheck_enc_id",
|
||||
ReadCount: 10000,
|
||||
}))
|
||||
})
|
||||
checkGormError(err, 0)
|
||||
} else {
|
||||
err := GLOBAL_DB.Ping()
|
||||
if err != nil {
|
||||
@ -455,16 +424,6 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
|
||||
func applyDeletionRequestsToBackend(ctx context.Context, request shared.DeletionRequest) (int, error) {
|
||||
tx := GLOBAL_DB.WithContext(ctx).Where("false")
|
||||
for _, message := range request.Messages.Ids {
|
||||
tx = tx.Or(GLOBAL_DB.WithContext(ctx).Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date))
|
||||
}
|
||||
result := tx.Delete(&shared.EncHistoryEntry{})
|
||||
checkGormResult(result)
|
||||
return int(result.RowsAffected), nil
|
||||
}
|
||||
|
||||
func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
|
||||
panic("refusing to wipe the DB for prod")
|
||||
@ -472,7 +431,9 @@ func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if !isTestEnvironment() {
|
||||
panic("refusing to wipe the DB non-test environment")
|
||||
}
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Exec("DELETE FROM enc_history_entries"))
|
||||
|
||||
err := GLOBAL_DB.EncHistoryClear(r.Context())
|
||||
checkGormError(err, 0)
|
||||
|
||||
w.Header().Set("Content-Length", "0")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@ -562,7 +523,7 @@ func cron(ctx context.Context) error {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = cleanDatabase(ctx)
|
||||
err = GLOBAL_DB.Clean(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -669,21 +630,21 @@ func InitDB() {
|
||||
var err error
|
||||
GLOBAL_DB, err = OpenDB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sqlDb, err := GLOBAL_DB.DB.DB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
panic(fmt.Errorf("OpenDB: %w", err))
|
||||
}
|
||||
|
||||
if err := GLOBAL_DB.Ping(); err != nil {
|
||||
panic(fmt.Errorf("ping: %w", err))
|
||||
}
|
||||
if isProductionEnvironment() {
|
||||
sqlDb.SetMaxIdleConns(10)
|
||||
if err := GLOBAL_DB.SetMaxIdleConns(10); err != nil {
|
||||
panic(fmt.Errorf("failed to set max idle conns: %w", err))
|
||||
}
|
||||
}
|
||||
if isTestEnvironment() {
|
||||
sqlDb.SetMaxIdleConns(1)
|
||||
if err := GLOBAL_DB.SetMaxIdleConns(1); err != nil {
|
||||
panic(fmt.Errorf("failed to set max idle conns: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -759,7 +720,8 @@ func feedbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
|
||||
}
|
||||
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
|
||||
checkGormResult(GLOBAL_DB.WithContext(r.Context()).Create(feedback))
|
||||
err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback)
|
||||
checkGormError(err, 0)
|
||||
|
||||
if GLOBAL_STATSD != nil {
|
||||
GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0)
|
||||
@ -834,58 +796,6 @@ func byteCountToString(b int) string {
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
|
||||
}
|
||||
|
||||
func cleanDatabase(ctx context.Context) error {
|
||||
r := GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10")
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
r = GLOBAL_DB.WithContext(ctx).Exec("DELETE FROM deletion_requests WHERE read_count > 100")
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deepCleanDatabase(ctx context.Context) {
|
||||
err := GLOBAL_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
r := tx.Exec(`
|
||||
CREATE TEMP TABLE temp_users_with_one_device AS (
|
||||
SELECT user_id
|
||||
FROM devices
|
||||
GROUP BY user_id
|
||||
HAVING COUNT(DISTINCT device_id) > 1
|
||||
)
|
||||
`)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
r = tx.Exec(`
|
||||
CREATE TEMP TABLE temp_inactive_users AS (
|
||||
SELECT user_id
|
||||
FROM usage_data
|
||||
WHERE last_used <= (now() - INTERVAL '90 days')
|
||||
)
|
||||
`)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
r = tx.Exec(`
|
||||
SELECT COUNT(*) FROM enc_history_entries WHERE
|
||||
date <= (now() - INTERVAL '90 days')
|
||||
AND user_id IN (SELECT * FROM temp_users_with_one_device)
|
||||
AND user_id IN (SELECT * FROM temp_inactive_users)
|
||||
`)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
fmt.Printf("Ran deep clean and deleted %d rows\n", r.RowsAffected)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to deep clean DB: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
func configureObservability(mux *httptrace.ServeMux) func() {
|
||||
// Profiler
|
||||
err := profiler.Start(
|
||||
@ -933,7 +843,11 @@ func main() {
|
||||
|
||||
if isProductionEnvironment() {
|
||||
defer configureObservability(mux)()
|
||||
go deepCleanDatabase(context.Background())
|
||||
go func() {
|
||||
if err := GLOBAL_DB.DeepClean(context.Background()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler))
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/ddworken/hishtory/internal/database"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -17,7 +18,6 @@ import (
|
||||
"github.com/ddworken/hishtory/shared/testutils"
|
||||
"github.com/go-test/deep"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestESubmitThenQuery(t *testing.T) {
|
||||
@ -564,15 +564,15 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
|
||||
apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
||||
|
||||
// Call cleanDatabase and just check that there are no panics
|
||||
testutils.Check(t, cleanDatabase(context.TODO()))
|
||||
testutils.Check(t, GLOBAL_DB.Clean(context.TODO()))
|
||||
}
|
||||
|
||||
func assertNoLeakedConnections(t *testing.T, db *gorm.DB) {
|
||||
sqlDB, err := db.DB()
|
||||
func assertNoLeakedConnections(t *testing.T, db *database.DB) {
|
||||
stats, err := db.Stats()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
numConns := sqlDB.Stats().OpenConnections
|
||||
numConns := stats.OpenConnections
|
||||
if numConns > 1 {
|
||||
t.Fatalf("expected DB to have not leak connections, actually have %d", numConns)
|
||||
}
|
||||
|
@ -86,6 +86,17 @@ func (db *DB) Ping() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) SetMaxIdleConns(n int) error {
|
||||
rawDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawDB.SetMaxIdleConns(n)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Stats() (sql.DBStats, error) {
|
||||
rawDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
@ -106,35 +117,6 @@ func (db *DB) DistinctUsers(ctx context.Context) (int64, error) {
|
||||
return numDistinctUsers, nil
|
||||
}
|
||||
|
||||
func (db *DB) DevicesCountForUser(ctx context.Context, userID string) (int64, error) {
|
||||
var existingDevicesCount int64
|
||||
tx := db.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount)
|
||||
if tx.Error != nil {
|
||||
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return existingDevicesCount, nil
|
||||
}
|
||||
|
||||
func (db *DB) DevicesCount(ctx context.Context) (int64, error) {
|
||||
var numDevices int64 = 0
|
||||
tx := db.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices)
|
||||
if tx.Error != nil {
|
||||
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return numDevices, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeviceCreate(ctx context.Context, device *shared.Device) error {
|
||||
tx := db.WithContext(ctx).Create(device)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) error {
|
||||
tx := db.WithContext(ctx).Create(req)
|
||||
if tx.Error != nil {
|
||||
@ -144,12 +126,144 @@ func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) EncHistoryEntryCount(ctx context.Context) (int64, error) {
|
||||
var numDbEntries int64
|
||||
tx := db.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries)
|
||||
func (db *DB) DumpRequestForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DumpRequest, error) {
|
||||
var dumpRequests []*shared.DumpRequest
|
||||
// Filter out ones requested by the hishtory instance that sent this request
|
||||
tx := db.WithContext(ctx).Where("user_id = ? AND requesting_device_id != ?", userID, deviceID).Find(&dumpRequests)
|
||||
if tx.Error != nil {
|
||||
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return dumpRequests, nil
|
||||
}
|
||||
|
||||
func (db *DB) DumpRequestDeleteForUserAndDevice(ctx context.Context, userID, deviceID string) error {
|
||||
tx := db.WithContext(ctx).Delete(&shared.DumpRequest{}, "user_id = ? AND requesting_device_id = ?", userID, deviceID)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) ApplyDeletionRequestsToBackend(ctx context.Context, request *shared.DeletionRequest) (int64, error) {
|
||||
tx := db.WithContext(ctx).Where("false")
|
||||
for _, message := range request.Messages.Ids {
|
||||
tx = tx.Or(db.WithContext(ctx).Where("user_id = ? AND device_id = ? AND date = ?", request.UserId, message.DeviceId, message.Date))
|
||||
}
|
||||
result := tx.Delete(&shared.EncHistoryEntry{})
|
||||
if tx.Error != nil {
|
||||
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return numDbEntries, nil
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeletionRequestInc(ctx context.Context, userID, deviceID string) error {
|
||||
tx := db.WithContext(ctx).Exec("UPDATE deletion_requests SET read_count = read_count + 1 WHERE user_id = ? AND destination_device_id = ?", userID, deviceID)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) DeletionRequestsForUserAndDevice(ctx context.Context, userID, deviceID string) ([]*shared.DeletionRequest, error) {
|
||||
var deletionRequests []*shared.DeletionRequest
|
||||
tx := db.WithContext(ctx).Where("user_id = ? AND destination_device_id = ?", userID, deviceID).Find(&deletionRequests)
|
||||
if tx.Error != nil {
|
||||
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return deletionRequests, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeletionRequestCreate(ctx context.Context, request *shared.DeletionRequest) error {
|
||||
userID := request.UserId
|
||||
|
||||
devices, err := db.DevicesForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.DevicesForUser: %w", err)
|
||||
}
|
||||
|
||||
if len(devices) == 0 {
|
||||
return fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", userID)
|
||||
}
|
||||
|
||||
fmt.Printf("db.DeletionRequestCreate: Found %d devices\n", len(devices))
|
||||
|
||||
// TODO: maybe this should be a transaction?
|
||||
for _, device := range devices {
|
||||
request.DestinationDeviceId = device.DeviceId
|
||||
tx := db.WithContext(ctx).Create(&request)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
}
|
||||
|
||||
numDeleted, err := db.ApplyDeletionRequestsToBackend(ctx, request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db.ApplyDeletionRequestsToBackend: %w", err)
|
||||
}
|
||||
fmt.Printf("addDeletionRequestHandler: Deleted %d rows in the backend\n", numDeleted)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) FeedbackCreate(ctx context.Context, feedback *shared.Feedback) error {
|
||||
tx := db.WithContext(ctx).Create(feedback)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Clean(ctx context.Context) error {
|
||||
r := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries WHERE read_count > 10")
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
r = db.WithContext(ctx).Exec("DELETE FROM deletion_requests WHERE read_count > 100")
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) DeepClean(ctx context.Context) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
r := tx.Exec(`
|
||||
CREATE TEMP TABLE temp_users_with_one_device AS (
|
||||
SELECT user_id
|
||||
FROM devices
|
||||
GROUP BY user_id
|
||||
HAVING COUNT(DISTINCT device_id) > 1
|
||||
)
|
||||
`)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
r = tx.Exec(`
|
||||
CREATE TEMP TABLE temp_inactive_users AS (
|
||||
SELECT user_id
|
||||
FROM usage_data
|
||||
WHERE last_used <= (now() - INTERVAL '90 days')
|
||||
)
|
||||
`)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
r = tx.Exec(`
|
||||
SELECT COUNT(*) FROM enc_history_entries WHERE
|
||||
date <= (now() - INTERVAL '90 days')
|
||||
AND user_id IN (SELECT * FROM temp_users_with_one_device)
|
||||
AND user_id IN (SELECT * FROM temp_inactive_users)
|
||||
`)
|
||||
if r.Error != nil {
|
||||
return r.Error
|
||||
}
|
||||
fmt.Printf("Ran deep clean and deleted %d rows\n", r.RowsAffected)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
69
internal/database/device.go
Normal file
69
internal/database/device.go
Normal file
@ -0,0 +1,69 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (db *DB) DevicesCountForUser(ctx context.Context, userID string) (int64, error) {
|
||||
var existingDevicesCount int64
|
||||
tx := db.WithContext(ctx).Model(&shared.Device{}).Where("user_id = ?", userID).Count(&existingDevicesCount)
|
||||
if tx.Error != nil {
|
||||
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return existingDevicesCount, nil
|
||||
}
|
||||
|
||||
func (db *DB) DevicesCount(ctx context.Context) (int64, error) {
|
||||
var numDevices int64 = 0
|
||||
tx := db.WithContext(ctx).Model(&shared.Device{}).Count(&numDevices)
|
||||
if tx.Error != nil {
|
||||
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return numDevices, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeviceCreate(ctx context.Context, device *shared.Device) error {
|
||||
tx := db.WithContext(ctx).Create(device)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) DeviceEntriesCreateChunk(ctx context.Context, devices []*shared.Device, entries []*shared.EncHistoryEntry, chunkSize int) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for _, device := range devices {
|
||||
for _, entry := range entries {
|
||||
entry.DeviceId = device.DeviceId
|
||||
}
|
||||
// Chunk the inserts to prevent the `extended protocol limited to 65535 parameters` error
|
||||
for _, entriesChunk := range shared.Chunks(entries, chunkSize) {
|
||||
resp := tx.Create(&entriesChunk)
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("resp.Error: %w", resp.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (db *DB) DevicesForUser(ctx context.Context, userID string) ([]*shared.Device, error) {
|
||||
var devices []*shared.Device
|
||||
tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&devices)
|
||||
if tx.Error != nil {
|
||||
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (db *DB) DeviceIncrementReadCounts(ctx context.Context, deviceID string) error {
|
||||
return db.WithContext(ctx).Exec("UPDATE enc_history_entries SET read_count = read_count + 1 WHERE device_id = ?", deviceID).Error
|
||||
}
|
70
internal/database/enchistory.go
Normal file
70
internal/database/enchistory.go
Normal file
@ -0,0 +1,70 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/ddworken/hishtory/shared"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (db *DB) EncHistoryEntryCount(ctx context.Context) (int64, error) {
|
||||
var numDbEntries int64
|
||||
tx := db.WithContext(ctx).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries)
|
||||
if tx.Error != nil {
|
||||
return 0, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return numDbEntries, nil
|
||||
}
|
||||
|
||||
func (db *DB) EncHistoryEntriesForUser(ctx context.Context, userID string) ([]*shared.EncHistoryEntry, error) {
|
||||
var historyEntries []*shared.EncHistoryEntry
|
||||
tx := db.WithContext(ctx).Where("user_id = ?", userID).Find(&historyEntries)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return historyEntries, nil
|
||||
}
|
||||
|
||||
func (db *DB) EncHistoryEntriesForDevice(ctx context.Context, deviceID string, limit int) ([]*shared.EncHistoryEntry, error) {
|
||||
var historyEntries []*shared.EncHistoryEntry
|
||||
tx := db.WithContext(ctx).Where("device_id = ? AND read_count < ?", deviceID, limit).Find(&historyEntries)
|
||||
|
||||
if tx.Error != nil {
|
||||
return nil, fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return historyEntries, nil
|
||||
}
|
||||
|
||||
func (db *DB) EncHistoryCreate(ctx context.Context, entry *shared.EncHistoryEntry) error {
|
||||
tx := db.WithContext(ctx).Create(entry)
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) EncHistoryCreateMulti(ctx context.Context, entries ...*shared.EncHistoryEntry) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for _, entry := range entries {
|
||||
resp := tx.Create(&entry)
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("resp.Error: %w", resp.Error)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (db *DB) EncHistoryClear(ctx context.Context) error {
|
||||
tx := db.WithContext(ctx).Exec("DELETE FROM enc_history_entries")
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("tx.Error: %w", tx.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user