diff --git a/backend/server/server.go b/backend/server/server.go
index 499d393..2b7695b 100644
--- a/backend/server/server.go
+++ b/backend/server/server.go
@@ -19,18 +19,14 @@ import (
 	pprofhttp "net/http/pprof"
 
 	"github.com/DataDog/datadog-go/statsd"
+	"github.com/ddworken/hishtory/internal/database"
 	"github.com/ddworken/hishtory/shared"
-	"github.com/jackc/pgx/v4/stdlib"
 	_ "github.com/lib/pq"
 	"github.com/rodaine/table"
-	sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
-	gormtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorm.io/gorm.v1"
 	httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
 	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
 	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
 	"gopkg.in/DataDog/dd-trace-go.v1/profiler"
-	"gorm.io/driver/postgres"
-	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"
 	"gorm.io/gorm/logger"
 )
@@ -40,22 +36,11 @@ const (
 )
 
 var (
-	GLOBAL_DB      *gorm.DB
+	GLOBAL_DB      *database.DB
 	GLOBAL_STATSD  *statsd.Client
 	ReleaseVersion string = "UNKNOWN"
 )
 
-type UsageData struct {
-	UserId            string    `json:"user_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"`
-	DeviceId          string    `json:"device_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"`
-	LastUsed          time.Time `json:"last_used"`
-	LastIp            string    `json:"last_ip"`
-	NumEntriesHandled int       `json:"num_entries_handled"`
-	LastQueried       time.Time `json:"last_queried"`
-	NumQueries        int       `json:"num_queries"`
-	Version           string    `json:"version"`
-}
-
 func getRequiredQueryParam(r *http.Request, queryParam string) string {
 	val := r.URL.Query().Get(queryParam)
 	if val == "" {
@@ -68,65 +53,72 @@ func getHishtoryVersion(r *http.Request) string {
 	return r.Header.Get("X-Hishtory-Version")
 }
 
-func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) {
-	var usageData []UsageData
-	GLOBAL_DB.WithContext(r.Context()).Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
+func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) error {
+	usageData, err := GLOBAL_DB.UsageDataFindByUserAndDevice(r.Context(), userId, deviceId)
+	if err != nil {
+		return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err)
+	}
 	if len(usageData) == 0 {
-		GLOBAL_DB.WithContext(r.Context()).Create(&UsageData{UserId: userId, DeviceId: deviceId, LastUsed: time.Now(), NumEntriesHandled: numEntriesHandled, Version: getHishtoryVersion(r)})
+		err := GLOBAL_DB.UsageDataCreate(
+			r.Context(),
+			&shared.UsageData{
+				UserId:            userId,
+				DeviceId:          deviceId,
+				LastUsed:          time.Now(),
+				NumEntriesHandled: numEntriesHandled,
+				Version:           getHishtoryVersion(r),
+			},
+		)
+		if err != nil {
+			return fmt.Errorf("db.UsageDataCreate: %w", err)
+		}
 	} else {
 		usage := usageData[0]
-		GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("user_id = ? AND device_id = ?", userId, deviceId).Update("last_used", time.Now()).Update("last_ip", getRemoteAddr(r))
+
+		if err := GLOBAL_DB.UsageDataUpdate(r.Context(), userId, deviceId, time.Now(), getRemoteAddr(r)); err != nil {
+			return fmt.Errorf("db.UsageDataUpdate: %w", err)
+		}
 		if numEntriesHandled > 0 {
-			GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntriesHandled, userId, deviceId)
+			if err := GLOBAL_DB.UsageDataUpdateNumEntriesHandled(r.Context(), userId, deviceId, numEntriesHandled); err != nil {
+				return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err)
+			}
 		}
 		if usage.Version != getHishtoryVersion(r) {
-			GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET version = ? WHERE user_id = ? AND device_id = ?", getHishtoryVersion(r), userId, deviceId)
+			if err := GLOBAL_DB.UsageDataUpdateVersion(r.Context(), userId, deviceId, getHishtoryVersion(r)); err != nil {
+				return fmt.Errorf("db.UsageDataUpdateVersion: %w", err)
+			}
 		}
 	}
 	if isQuery {
-		GLOBAL_DB.WithContext(r.Context()).Exec("UPDATE usage_data SET num_queries = COALESCE(num_queries, 0) + 1, last_queried = ? WHERE user_id = ? AND device_id = ?", time.Now(), userId, deviceId)
+		if err := GLOBAL_DB.UsageDataUpdateNumQueries(r.Context(), userId, deviceId); err != nil {
+			return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err)
+		}
 	}
+
+	return nil
 }
 
 func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
-	query := `
-	SELECT
-		MIN(devices.registration_date) as registration_date,
-		COUNT(DISTINCT devices.device_id) as num_devices,
-		SUM(usage_data.num_entries_handled) as num_history_entries,
-		MAX(usage_data.last_used) as last_active,
-		COALESCE(STRING_AGG(DISTINCT usage_data.last_ip, ', ') FILTER (WHERE usage_data.last_ip != 'Unknown' AND usage_data.last_ip != 'UnknownIp'), 'Unknown') as ip_addresses,
-		COALESCE(SUM(usage_data.num_queries), 0) as num_queries,
-		COALESCE(MAX(usage_data.last_queried), 'January 1, 1970') as last_queried,
-		STRING_AGG(DISTINCT usage_data.version, ', ') as versions
-	FROM devices
-	INNER JOIN usage_data ON devices.device_id = usage_data.device_id
-	GROUP BY devices.user_id
-	ORDER BY registration_date
-	`
-	rows, err := GLOBAL_DB.WithContext(r.Context()).Raw(query).Rows()
+	usageData, err := GLOBAL_DB.UsageDataStats(r.Context())
 	if err != nil {
-		panic(err)
+		panic(fmt.Errorf("db.UsageDataStats: %w", err))
 	}
-	defer rows.Close()
+
 	tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs")
 	tbl.WithWriter(w)
-	for rows.Next() {
-		var registrationDate time.Time
-		var numDevices int
-		var numEntries int
-		var lastUsedDate time.Time
-		var ipAddresses string
-		var numQueries int
-		var lastQueried time.Time
-		var versions string
-		err = rows.Scan(&registrationDate, &numDevices, &numEntries, &lastUsedDate, &ipAddresses, &numQueries, &lastQueried, &versions)
-		if err != nil {
-			panic(err)
-		}
-		versions = strings.ReplaceAll(strings.ReplaceAll(versions, "Unknown", ""), ", ", "")
-		lastQueryStr := strings.ReplaceAll(lastQueried.Format("2006-01-02"), "1970-01-01", "")
-		tbl.AddRow(registrationDate.Format("2006-01-02"), numDevices, numEntries, numQueries, lastUsedDate.Format("2006-01-02"), lastQueryStr, versions, ipAddresses)
+	for _, data := range usageData {
+		versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "")
+		lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(time.DateOnly), "1970-01-01", "")
+		tbl.AddRow(
+			data.RegistrationDate.Format(time.DateOnly),
+			data.NumDevices,
+			data.NumEntries,
+			data.NumQueries,
+			data.LastUsedDate.Format(time.DateOnly),
+			lastQueryStr,
+			versions,
+			data.IpAddresses,
+		)
 	}
 	tbl.Print()
 }
@@ -138,15 +130,15 @@ func statsHandler(w http.ResponseWriter, r *http.Request) {
 		Total int
 	}
 	nep := numEntriesProcessed{}
-	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Select("SUM(num_entries_handled) as total").Find(&nep))
+	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.UsageData{}).Select("SUM(num_entries_handled) as total").Find(&nep))
 	var numDbEntries int64 = 0
 	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.EncHistoryEntry{}).Count(&numDbEntries))
 
 	lastWeek := time.Now().AddDate(0, 0, -7)
 	var weeklyActiveInstalls int64 = 0
-	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("last_used > ?", lastWeek).Count(&weeklyActiveInstalls))
+	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.UsageData{}).Where("last_used > ?", lastWeek).Count(&weeklyActiveInstalls))
 	var weeklyQueryUsers int64 = 0
-	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&UsageData{}).Where("last_queried > ?", lastWeek).Count(&weeklyQueryUsers))
+	checkGormResult(GLOBAL_DB.WithContext(r.Context()).Model(&shared.UsageData{}).Where("last_queried > ?", lastWeek).Count(&weeklyQueryUsers))
 	var lastRegistration string = ""
 	row := GLOBAL_DB.WithContext(r.Context()).Raw("select to_char(max(registration_date), 'DD Month YYYY HH24:MI') from devices").Row()
 	err := row.Scan(&lastRegistration)
@@ -451,13 +443,9 @@ func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
 			ReadCount: 10000,
 		}))
 	} else {
-		db, err := GLOBAL_DB.DB()
+		err := GLOBAL_DB.Ping()
 		if err != nil {
-			panic(fmt.Sprintf("failed to get DB: %v", err))
-		}
-		err = db.Ping()
-		if err != nil {
-			panic(fmt.Sprintf("failed to ping DB: %v", err))
+			panic(fmt.Sprintf("failed to ping DB: %v", err))
 		}
 	}
 	w.Write([]byte("OK"))
@@ -487,11 +475,12 @@ func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
 }
 
 func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
-	sqlDb, err := GLOBAL_DB.DB()
+	stats, err := GLOBAL_DB.Stats()
 	if err != nil {
 		panic(err)
 	}
-	_, _ = fmt.Fprintf(w, "%#v", sqlDb.Stats().OpenConnections)
+
+	_, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections)
 }
 
 func isTestEnvironment() bool {
@@ -502,19 +491,21 @@ func isProductionEnvironment() bool {
 	return os.Getenv("HISHTORY_ENV") == "prod"
 }
 
-func OpenDB() (*gorm.DB, error) {
+func OpenDB() (*database.DB, error) {
 	if isTestEnvironment() {
-		db, err := gorm.Open(sqlite.Open("file::memory:?_journal_mode=WAL&cache=shared"), &gorm.Config{})
+		db, err := database.OpenSQLite("file::memory:?_journal_mode=WAL&cache=shared", &gorm.Config{})
 		if err != nil {
 			return nil, fmt.Errorf("failed to connect to the DB: %w", err)
 		}
-		underlyingDb, err := db.DB()
+		underlyingDb, err := db.DB.DB()
 		if err != nil {
 			return nil, fmt.Errorf("failed to access underlying DB: %w", err)
 		}
 		underlyingDb.SetMaxOpenConns(1)
 		db.Exec("PRAGMA journal_mode = WAL")
-		AddDatabaseTables(db)
+		if err := db.AddDatabaseTables(); err != nil {
+			return nil, fmt.Errorf("failed to add database tables: %w", err)
+		}
 		return db, nil
 	}
 
@@ -531,41 +522,33 @@ func OpenDB() (*gorm.DB, error) {
 		sqliteDb = os.Getenv("HISHTORY_SQLITE_DB")
 	}
 
-	var db *gorm.DB
+	config := gorm.Config{Logger: customLogger}
+
+	var db *database.DB
 	if sqliteDb != "" {
 		var err error
-		db, err = gorm.Open(sqlite.Open(sqliteDb), &gorm.Config{Logger: customLogger})
+		db, err = database.OpenSQLite(sqliteDb, &config)
 		if err != nil {
 			return nil, fmt.Errorf("failed to connect to the DB: %w", err)
 		}
 	} else {
+		var err error
 		postgresDb := fmt.Sprintf(PostgresDb, os.Getenv("POSTGRESQL_PASSWORD"))
 		if os.Getenv("HISHTORY_POSTGRES_DB") != "" {
 			postgresDb = os.Getenv("HISHTORY_POSTGRES_DB")
 		}
-		sqltrace.Register("pgx", &stdlib.Driver{}, sqltrace.WithServiceName("hishtory-api"))
-		sqlDb, err := sqltrace.Open("pgx", postgresDb)
-		if err != nil {
-			log.Fatal(err)
-		}
-		db, err = gormtrace.Open(postgres.New(postgres.Config{Conn: sqlDb}), &gorm.Config{Logger: customLogger})
+
+		db, err = database.OpenPostgres(postgresDb, &config)
 		if err != nil {
 			return nil, fmt.Errorf("failed to connect to the DB: %w", err)
 		}
 	}
-	AddDatabaseTables(db)
+	if err := db.AddDatabaseTables(); err != nil {
+		return nil, fmt.Errorf("failed to add database tables: %w", err)
+	}
 	return db, nil
 }
 
-func AddDatabaseTables(db *gorm.DB) {
-	db.AutoMigrate(&shared.EncHistoryEntry{})
-	db.AutoMigrate(&shared.Device{})
-	db.AutoMigrate(&UsageData{})
-	db.AutoMigrate(&shared.DumpRequest{})
-	db.AutoMigrate(&shared.DeletionRequest{})
-	db.AutoMigrate(&shared.Feedback{})
-}
-
 func init() {
 	if ReleaseVersion == "UNKNOWN" && !isTestEnvironment() {
 		panic("server.go was built without a ReleaseVersion!")
@@ -688,13 +671,13 @@ func InitDB() {
 	if err != nil {
 		panic(err)
 	}
-	sqlDb, err := GLOBAL_DB.DB()
+	sqlDb, err := GLOBAL_DB.DB.DB()
 	if err != nil {
 		panic(err)
 	}
-	err = sqlDb.Ping()
-	if err != nil {
-		panic(err)
+
+	if err := GLOBAL_DB.Ping(); err != nil {
+		panic(fmt.Errorf("ping: %w", err))
 	}
 	if isProductionEnvironment() {
 		sqlDb.SetMaxIdleConns(10)
diff --git a/internal/database/db.go b/internal/database/db.go
new file mode 100644
index 0000000..bb19a88
--- /dev/null
+++ b/internal/database/db.go
@@ -0,0 +1,108 @@
+// Package database wraps the gorm connection used by the hishtory server and
+// provides helpers for opening, migrating and inspecting the database.
+package database
+
+import (
+	"database/sql"
+	"fmt"
+
+	"github.com/ddworken/hishtory/shared"
+	"github.com/jackc/pgx/v4/stdlib"
+	_ "github.com/lib/pq"
+	sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
+	gormtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gorm.io/gorm.v1"
+	"gorm.io/driver/postgres"
+	"gorm.io/driver/sqlite"
+	"gorm.io/gorm"
+)
+
+// DB embeds *gorm.DB, so the full gorm API stays available next to the helper
+// methods defined in this package.
+type DB struct {
+	*gorm.DB
+}
+
+// OpenSQLite opens the SQLite database at dsn using the given gorm config.
+func OpenSQLite(dsn string, config *gorm.Config) (*DB, error) {
+	db, err := gorm.Open(sqlite.Open(dsn), config)
+	if err != nil {
+		return nil, fmt.Errorf("gorm.Open: %w", err)
+	}
+
+	return &DB{db}, nil
+}
+
+// OpenPostgres opens the Postgres database at dsn via the pgx driver, which it
+// registers with sqltrace under the "hishtory-api" service name.
+func OpenPostgres(dsn string, config *gorm.Config) (*DB, error) {
+	sqltrace.Register("pgx", &stdlib.Driver{}, sqltrace.WithServiceName("hishtory-api"))
+	sqlDb, err := sqltrace.Open("pgx", dsn)
+	if err != nil {
+		return nil, fmt.Errorf("sqltrace.Open: %w", err)
+	}
+	db, err := gormtrace.Open(postgres.New(postgres.Config{Conn: sqlDb}), config)
+	if err != nil {
+		return nil, fmt.Errorf("gormtrace.Open: %w", err)
+	}
+
+	return &DB{db}, nil
+}
+
+// AddDatabaseTables runs AutoMigrate for each model listed below and returns
+// the first migration error encountered.
+func (db *DB) AddDatabaseTables() error {
+	models := []any{
+		&shared.EncHistoryEntry{},
+		&shared.Device{},
+		&shared.UsageData{},
+		&shared.DumpRequest{},
+		&shared.DeletionRequest{},
+		&shared.Feedback{},
+	}
+
+	for _, model := range models {
+		if err := db.AutoMigrate(model); err != nil {
+			return fmt.Errorf("db.AutoMigrate: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// Close closes the underlying *sql.DB.
+func (db *DB) Close() error {
+	rawDB, err := db.DB.DB()
+	if err != nil {
+		return fmt.Errorf("db.DB.DB: %w", err)
+	}
+
+	if err := rawDB.Close(); err != nil {
+		return fmt.Errorf("rawDB.Close: %w", err)
+	}
+
+	return nil
+}
+
+// Ping pings the database through the underlying *sql.DB.
+func (db *DB) Ping() error {
+	rawDB, err := db.DB.DB()
+	if err != nil {
+		return fmt.Errorf("db.DB.DB: %w", err)
+	}
+
+	if err := rawDB.Ping(); err != nil {
+		return fmt.Errorf("rawDB.Ping: %w", err)
+	}
+
+	return nil
+}
+
+// Stats returns the connection statistics reported by the underlying *sql.DB.
+func (db *DB) Stats() (sql.DBStats, error) {
+	rawDB, err := db.DB.DB()
+	if err != nil {
+		return sql.DBStats{}, fmt.Errorf("db.DB.DB: %w", err)
+	}
+
+	return rawDB.Stats(), nil
+}
diff --git a/internal/database/usagedata.go b/internal/database/usagedata.go
new file mode 100644
index 0000000..eb4ec94
--- /dev/null
+++ b/internal/database/usagedata.go
@@ -0,0 +1,150 @@
+package database
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/ddworken/hishtory/shared"
+)
+
+// UsageDataFindByUserAndDevice returns the usage rows stored for the given
+// user/device pair. No matching row yields an empty slice and a nil error.
+func (db *DB) UsageDataFindByUserAndDevice(ctx context.Context, userId, deviceId string) ([]shared.UsageData, error) {
+	var usageData []shared.UsageData
+
+	tx := db.DB.WithContext(ctx).Where("user_id = ? AND device_id = ?", userId, deviceId).Find(&usageData)
+	if tx.Error != nil {
+		return nil, fmt.Errorf("db.WithContext.Where.Find: %w", tx.Error)
+	}
+
+	return usageData, nil
+}
+
+// UsageDataCreate inserts usageData as a new row.
+func (db *DB) UsageDataCreate(ctx context.Context, usageData *shared.UsageData) error {
+	tx := db.DB.WithContext(ctx).Create(usageData)
+	if tx.Error != nil {
+		return fmt.Errorf("db.WithContext.Create: %w", tx.Error)
+	}
+
+	return nil
+}
+
+// UsageDataUpdate updates the entry for a given userID/deviceID pair with the lastUsed and lastIP values.
+func (db *DB) UsageDataUpdate(ctx context.Context, userId, deviceId string, lastUsed time.Time, lastIP string) error {
+	tx := db.DB.WithContext(ctx).Model(&shared.UsageData{}).
+		Where("user_id = ? AND device_id = ?", userId, deviceId).
+		Update("last_used", lastUsed).
+		Update("last_ip", lastIP)
+
+	if tx.Error != nil {
+		return fmt.Errorf("db.WithContext.Model.Where.Update: %w", tx.Error)
+	}
+
+	return nil
+}
+
+// UsageDataUpdateNumEntriesHandled adds numEntriesHandled to the stored counter
+// of the given user/device pair, treating a NULL counter as zero.
+func (db *DB) UsageDataUpdateNumEntriesHandled(ctx context.Context, userId, deviceId string, numEntriesHandled int) error {
+	tx := db.DB.WithContext(ctx).Exec("UPDATE usage_data SET num_entries_handled = COALESCE(num_entries_handled, 0) + ? WHERE user_id = ? AND device_id = ?", numEntriesHandled, userId, deviceId)
+
+	if tx.Error != nil {
+		return fmt.Errorf("db.WithContext.Exec: %w", tx.Error)
+	}
+
+	return nil
+}
+
+// UsageDataUpdateVersion stores version for the given user/device pair.
+func (db *DB) UsageDataUpdateVersion(ctx context.Context, userID, deviceID string, version string) error {
+	tx := db.DB.WithContext(ctx).Exec("UPDATE usage_data SET version = ? WHERE user_id = ? AND device_id = ?", version, userID, deviceID)
+
+	if tx.Error != nil {
+		return fmt.Errorf("db.WithContext.Exec: %w", tx.Error)
+	}
+
+	return nil
+}
+
+// UsageDataUpdateNumQueries increments the query counter of the given
+// user/device pair and sets its last_queried column to the current time.
+func (db *DB) UsageDataUpdateNumQueries(ctx context.Context, userID, deviceID string) error {
+	tx := db.DB.WithContext(ctx).Exec("UPDATE usage_data SET num_queries = COALESCE(num_queries, 0) + 1, last_queried = ? WHERE user_id = ? AND device_id = ?", time.Now(), userID, deviceID)
+
+	if tx.Error != nil {
+		return fmt.Errorf("db.WithContext.Exec: %w", tx.Error)
+	}
+
+	return nil
+}
+
+// UsageDataStats holds one result row of usageDataStatsQuery, i.e. the usage
+// summary aggregated for a single user.
+type UsageDataStats struct {
+	RegistrationDate time.Time
+	NumDevices       int
+	NumEntries       int
+	LastUsedDate     time.Time
+	IpAddresses      string
+	NumQueries       int
+	LastQueried      time.Time
+	Versions         string
+}
+
+// usageDataStatsQuery aggregates devices joined with usage_data per user.
+const usageDataStatsQuery = `
+	SELECT
+		MIN(devices.registration_date) as registration_date,
+		COUNT(DISTINCT devices.device_id) as num_devices,
+		SUM(usage_data.num_entries_handled) as num_history_entries,
+		MAX(usage_data.last_used) as last_active,
+		COALESCE(STRING_AGG(DISTINCT usage_data.last_ip, ', ') FILTER (WHERE usage_data.last_ip != 'Unknown' AND usage_data.last_ip != 'UnknownIp'), 'Unknown') as ip_addresses,
+		COALESCE(SUM(usage_data.num_queries), 0) as num_queries,
+		COALESCE(MAX(usage_data.last_queried), 'January 1, 1970') as last_queried,
+		STRING_AGG(DISTINCT usage_data.version, ', ') as versions
+	FROM devices
+	INNER JOIN usage_data ON devices.device_id = usage_data.device_id
+	GROUP BY devices.user_id
+	ORDER BY registration_date
+	`
+
+// UsageDataStats runs usageDataStatsQuery and returns one entry per result row.
+func (db *DB) UsageDataStats(ctx context.Context) ([]*UsageDataStats, error) {
+	var resp []*UsageDataStats
+
+	rows, err := db.DB.WithContext(ctx).Raw(usageDataStatsQuery).Rows()
+	if err != nil {
+		return nil, fmt.Errorf("db.WithContext.Raw.Rows: %w", err)
+	}
+	defer rows.Close()
+
+	for rows.Next() {
+		var usageData UsageDataStats
+
+		err := rows.Scan(
+			&usageData.RegistrationDate,
+			&usageData.NumDevices,
+			&usageData.NumEntries,
+			&usageData.LastUsedDate,
+			&usageData.IpAddresses,
+			&usageData.NumQueries,
+			&usageData.LastQueried,
+			&usageData.Versions,
+		)
+		if err != nil {
+			return nil, fmt.Errorf("rows.Scan: %w", err)
+		}
+
+		resp = append(resp, &usageData)
+	}
+
+	// rows.Next returns false on exhaustion and on failure alike, so report an
+	// iteration error instead of returning a silently truncated result.
+	if err := rows.Err(); err != nil {
+		return nil, fmt.Errorf("rows.Err: %w", err)
+	}
+
+	return resp, nil
+}
diff --git a/shared/usagedata.go b/shared/usagedata.go
new file mode 100644
index 0000000..cc515fb
--- /dev/null
+++ b/shared/usagedata.go
@@ -0,0 +1,16 @@
+package shared
+
+import "time"
+
+// UsageData is the usage record stored for one user/device pair; UserId and
+// DeviceId together form its unique index.
+type UsageData struct {
+	UserId            string    `json:"user_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"`
+	DeviceId          string    `json:"device_id" gorm:"not null; uniqueIndex:usageDataUniqueIndex"`
+	LastUsed          time.Time `json:"last_used"`
+	LastIp            string    `json:"last_ip"`
+	NumEntriesHandled int       `json:"num_entries_handled"`
+	LastQueried       time.Time `json:"last_queried"`
+	NumQueries        int       `json:"num_queries"`
+	Version           string    `json:"version"`
+}