mirror of
https://github.com/ddworken/hishtory.git
synced 2025-06-24 05:51:38 +02:00
Merge pull request #106 from lsmoura/sergio/isolated-server
Create isolated server struct that encapsulates all server logic
This commit is contained in:
commit
b478eadeae
2
.github/workflows/go-test.yml
vendored
2
.github/workflows/go-test.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
os: [ubuntu-latest, macos-latest]
|
os: [ubuntu-latest, macos-latest]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v3
|
uses: actions/setup-go@v3
|
||||||
with:
|
with:
|
||||||
|
@ -2,439 +2,32 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"html"
|
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pprofhttp "net/http/pprof"
|
|
||||||
|
|
||||||
"github.com/DataDog/datadog-go/statsd"
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
"github.com/ddworken/hishtory/internal/database"
|
"github.com/ddworken/hishtory/internal/database"
|
||||||
"github.com/ddworken/hishtory/shared"
|
"github.com/ddworken/hishtory/internal/release"
|
||||||
|
"github.com/ddworken/hishtory/internal/server"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
"github.com/rodaine/table"
|
|
||||||
httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
|
|
||||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
|
|
||||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
|
||||||
"gopkg.in/DataDog/dd-trace-go.v1/profiler"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable"
|
PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable"
|
||||||
|
StatsdSocket = "unix:///var/run/datadog/dsd.socket"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
GLOBAL_DB *database.DB
|
GLOBAL_DB *database.DB
|
||||||
GLOBAL_STATSD *statsd.Client
|
GLOBAL_STATSD *statsd.Client
|
||||||
ReleaseVersion string = "UNKNOWN"
|
ReleaseVersion string
|
||||||
)
|
)
|
||||||
|
|
||||||
func getRequiredQueryParam(r *http.Request, queryParam string) string {
|
|
||||||
val := r.URL.Query().Get(queryParam)
|
|
||||||
if val == "" {
|
|
||||||
panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam))
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHishtoryVersion(r *http.Request) string {
|
|
||||||
return r.Header.Get("X-Hishtory-Version")
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) error {
|
|
||||||
var usageData []shared.UsageData
|
|
||||||
usageData, err := GLOBAL_DB.UsageDataFindByUserAndDevice(r.Context(), userId, deviceId)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err)
|
|
||||||
}
|
|
||||||
if len(usageData) == 0 {
|
|
||||||
err := GLOBAL_DB.CreateUsageData(
|
|
||||||
r.Context(),
|
|
||||||
&shared.UsageData{
|
|
||||||
UserId: userId,
|
|
||||||
DeviceId: deviceId,
|
|
||||||
LastUsed: time.Now(),
|
|
||||||
NumEntriesHandled: numEntriesHandled,
|
|
||||||
Version: getHishtoryVersion(r),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("db.CreateUsageData: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
usage := usageData[0]
|
|
||||||
|
|
||||||
if err := GLOBAL_DB.UpdateUsageData(r.Context(), userId, deviceId, time.Now(), getRemoteAddr(r)); err != nil {
|
|
||||||
return fmt.Errorf("db.UpdateUsageData: %w", err)
|
|
||||||
}
|
|
||||||
if numEntriesHandled > 0 {
|
|
||||||
if err := GLOBAL_DB.UpdateUsageDataForNumEntriesHandled(r.Context(), userId, deviceId, numEntriesHandled); err != nil {
|
|
||||||
return fmt.Errorf("db.UpdateUsageDataForNumEntriesHandled: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if usage.Version != getHishtoryVersion(r) {
|
|
||||||
if err := GLOBAL_DB.UpdateUsageDataClientVersion(r.Context(), userId, deviceId, getHishtoryVersion(r)); err != nil {
|
|
||||||
return fmt.Errorf("db.UpdateUsageDataClientVersion: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isQuery {
|
|
||||||
if err := GLOBAL_DB.UpdateUsageDataNumberQueries(r.Context(), userId, deviceId); err != nil {
|
|
||||||
return fmt.Errorf("db.UpdateUsageDataNumberQueries: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
usageData, err := GLOBAL_DB.UsageDataStats(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Errorf("db.UsageDataStats: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs")
|
|
||||||
tbl.WithWriter(w)
|
|
||||||
for _, data := range usageData {
|
|
||||||
versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "")
|
|
||||||
lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "")
|
|
||||||
tbl.AddRow(
|
|
||||||
data.RegistrationDate.Format(shared.DateOnly),
|
|
||||||
data.NumDevices,
|
|
||||||
data.NumEntries,
|
|
||||||
data.NumQueries,
|
|
||||||
data.LastUsedDate.Format(shared.DateOnly),
|
|
||||||
lastQueryStr,
|
|
||||||
versions,
|
|
||||||
data.IpAddresses,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
tbl.Print()
|
|
||||||
}
|
|
||||||
|
|
||||||
func statsHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
numDevices, err := GLOBAL_DB.CountAllDevices(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
numEntriesProcessed, err := GLOBAL_DB.UsageDataTotal(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
numDbEntries, err := GLOBAL_DB.CountHistoryEntries(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
oneWeek := time.Hour * 24 * 7
|
|
||||||
weeklyActiveInstalls, err := GLOBAL_DB.CountActiveInstalls(r.Context(), oneWeek)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
weeklyQueryUsers, err := GLOBAL_DB.CountQueryUsers(r.Context(), oneWeek)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
lastRegistration, err := GLOBAL_DB.DateOfLastRegistration(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
_, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices)
|
|
||||||
_, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed)
|
|
||||||
_, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries)
|
|
||||||
_, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls)
|
|
||||||
_, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers)
|
|
||||||
_, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration)
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
data, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var entries []*shared.EncHistoryEntry
|
|
||||||
err = json.Unmarshal(data, &entries)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
|
||||||
}
|
|
||||||
fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries))
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = updateUsageData(r, entries[0].UserId, entries[0].DeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false)
|
|
||||||
|
|
||||||
devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
if len(devices) == 0 {
|
|
||||||
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
|
|
||||||
}
|
|
||||||
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
|
|
||||||
|
|
||||||
err = GLOBAL_DB.AddHistoryEntriesForAllDevices(r.Context(), devices, entries)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
|
|
||||||
}
|
|
||||||
if GLOBAL_STATSD != nil {
|
|
||||||
GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
|
||||||
_ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false)
|
|
||||||
historyEntries, err := GLOBAL_DB.AllHistoryEntriesForUser(r.Context(), userId)
|
|
||||||
checkGormError(err, 1)
|
|
||||||
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
|
|
||||||
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
|
||||||
_ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, true)
|
|
||||||
|
|
||||||
// Delete any entries that match a pending deletion request
|
|
||||||
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
for _, request := range deletionRequests {
|
|
||||||
_, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then retrieve
|
|
||||||
historyEntries, err := GLOBAL_DB.HistoryEntriesForDevice(r.Context(), deviceId, 5)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
|
|
||||||
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids
|
|
||||||
// blocking the entire response. This does have a potential race condition, but that is fine.
|
|
||||||
if isProductionEnvironment() {
|
|
||||||
go func() {
|
|
||||||
span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount")
|
|
||||||
err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId)
|
|
||||||
span.Finish(tracer.WithError(err))
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId)
|
|
||||||
if err != nil {
|
|
||||||
panic("failed to increment read counts")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if GLOBAL_STATSD != nil {
|
|
||||||
GLOBAL_STATSD.Incr("hishtory.query", []string{}, 1.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRemoteAddr(r *http.Request) string {
|
|
||||||
addr, ok := r.Header["X-Real-Ip"]
|
|
||||||
if !ok || len(addr) == 0 {
|
|
||||||
return "UnknownIp"
|
|
||||||
}
|
|
||||||
return addr[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if getMaximumNumberOfAllowedUsers() < math.MaxInt {
|
|
||||||
numDistinctUsers, err := GLOBAL_DB.DistinctUsers(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Errorf("db.DistinctUsers: %w", err))
|
|
||||||
}
|
|
||||||
if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) {
|
|
||||||
panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
|
||||||
|
|
||||||
existingDevicesCount, err := GLOBAL_DB.CountDevicesForUser(r.Context(), userId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
|
|
||||||
if err := GLOBAL_DB.CreateDevice(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil {
|
|
||||||
checkGormError(err, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
if existingDevicesCount > 0 {
|
|
||||||
err := GLOBAL_DB.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})
|
|
||||||
checkGormError(err, 0)
|
|
||||||
}
|
|
||||||
_ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false)
|
|
||||||
|
|
||||||
if GLOBAL_STATSD != nil {
|
|
||||||
GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
|
||||||
var dumpRequests []*shared.DumpRequest
|
|
||||||
// Filter out ones requested by the hishtory instance that sent this request
|
|
||||||
dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(dumpRequests); err != nil {
|
|
||||||
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
|
||||||
srcDeviceId := getRequiredQueryParam(r, "source_device_id")
|
|
||||||
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
|
|
||||||
data, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var entries []*shared.EncHistoryEntry
|
|
||||||
err = json.Unmarshal(data, &entries)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
|
||||||
}
|
|
||||||
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
|
|
||||||
|
|
||||||
// sanity check
|
|
||||||
for _, entry := range entries {
|
|
||||||
entry.DeviceId = requestingDeviceId
|
|
||||||
if entry.UserId != userId {
|
|
||||||
panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = GLOBAL_DB.AddHistoryEntries(r.Context(), entries...)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
_ = updateUsageData(r, userId, srcDeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false)
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
commitHash := getRequiredQueryParam(r, "commit_hash")
|
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
|
||||||
forcedBanner := r.URL.Query().Get("forced_banner")
|
|
||||||
fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner)
|
|
||||||
if getHishtoryVersion(r) == "v0.160" {
|
|
||||||
w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory."))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Write([]byte(html.EscapeString(forcedBanner)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userId := getRequiredQueryParam(r, "user_id")
|
|
||||||
deviceId := getRequiredQueryParam(r, "device_id")
|
|
||||||
|
|
||||||
// Increment the ReadCount
|
|
||||||
err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
// Return all the deletion requests
|
|
||||||
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
if err := json.NewEncoder(w).Encode(deletionRequests); err != nil {
|
|
||||||
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
data, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var request shared.DeletionRequest
|
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &request); err != nil {
|
|
||||||
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
|
||||||
}
|
|
||||||
request.ReadCount = 0
|
|
||||||
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
|
|
||||||
|
|
||||||
err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if isProductionEnvironment() {
|
|
||||||
encHistoryEntryCount, err := GLOBAL_DB.CountHistoryEntries(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
if encHistoryEntryCount < 1000 {
|
|
||||||
panic("Suspiciously few enc history entries!")
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceCount, err := GLOBAL_DB.CountAllDevices(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
if deviceCount < 100 {
|
|
||||||
panic("Suspiciously few devices!")
|
|
||||||
}
|
|
||||||
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
|
|
||||||
err = GLOBAL_DB.AddHistoryEntries(r.Context(), &shared.EncHistoryEntry{
|
|
||||||
EncryptedData: []byte("data"),
|
|
||||||
Nonce: []byte("nonce"),
|
|
||||||
DeviceId: "healthcheck_device_id",
|
|
||||||
UserId: "healthcheck_user_id",
|
|
||||||
Date: time.Now(),
|
|
||||||
EncryptedId: "healthcheck_enc_id",
|
|
||||||
ReadCount: 10000,
|
|
||||||
})
|
|
||||||
checkGormError(err, 0)
|
|
||||||
} else {
|
|
||||||
err := GLOBAL_DB.Ping()
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Errorf("failed to ping DB: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
|
|
||||||
panic("refusing to wipe the DB for prod")
|
|
||||||
}
|
|
||||||
if !isTestEnvironment() {
|
|
||||||
panic("refusing to wipe the DB non-test environment")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := GLOBAL_DB.Unsafe_DeleteAllHistoryEntries(r.Context())
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
stats, err := GLOBAL_DB.Stats()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTestEnvironment() bool {
|
func isTestEnvironment() bool {
|
||||||
return os.Getenv("HISHTORY_TEST") != ""
|
return os.Getenv("HISHTORY_TEST") != ""
|
||||||
}
|
}
|
||||||
@ -504,122 +97,45 @@ func OpenDB() (*database.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if ReleaseVersion == "UNKNOWN" && !isTestEnvironment() {
|
release.Version = ReleaseVersion
|
||||||
|
if release.Version == "UNKNOWN" && !isTestEnvironment() {
|
||||||
panic("server.go was built without a ReleaseVersion!")
|
panic("server.go was built without a ReleaseVersion!")
|
||||||
}
|
}
|
||||||
InitDB()
|
InitDB()
|
||||||
go runBackgroundJobs(context.Background())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func cron(ctx context.Context) error {
|
func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error {
|
||||||
err := updateReleaseVersion()
|
if err := release.UpdateReleaseVersion(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("updateReleaseVersion: %w", err)
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
err = GLOBAL_DB.Clean(ctx)
|
|
||||||
if err != nil {
|
if err := db.Clean(ctx); err != nil {
|
||||||
panic(err)
|
return fmt.Errorf("db.Clean: %w", err)
|
||||||
}
|
}
|
||||||
if GLOBAL_STATSD != nil {
|
if stats != nil {
|
||||||
err = GLOBAL_STATSD.Flush()
|
if err := stats.Flush(); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("stats.Flush: %w", err)
|
||||||
panic(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runBackgroundJobs(ctx context.Context) {
|
func runBackgroundJobs(ctx context.Context, srv *server.Server) {
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
for {
|
for {
|
||||||
err := cron(ctx)
|
err := cron(ctx, GLOBAL_DB, GLOBAL_STATSD)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Cron failure: %v", err)
|
fmt.Printf("Cron failure: %v", err)
|
||||||
|
|
||||||
|
// cron no longer panics, panicking here.
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
|
srv.UpdateReleaseVersion(release.Version, release.BuildUpdateInfo(release.Version))
|
||||||
time.Sleep(10 * time.Minute)
|
time.Sleep(10 * time.Minute)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func triggerCronHandler(w http.ResponseWriter, r *http.Request) {
|
func InitDB() *database.DB {
|
||||||
err := cron(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
type releaseInfo struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateReleaseVersion() error {
|
|
||||||
resp, err := http.Get("https://api.github.com/repos/ddworken/hishtory/releases/latest")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get latest release version: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read github API response body: %w", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode == 403 && strings.Contains(string(respBody), "API rate limit exceeded for ") {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return fmt.Errorf("failed to call github API, status_code=%d, body=%#v", resp.StatusCode, string(respBody))
|
|
||||||
}
|
|
||||||
var info releaseInfo
|
|
||||||
err = json.Unmarshal(respBody, &info)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse github API response: %w", err)
|
|
||||||
}
|
|
||||||
latestVersionTag := info.Name
|
|
||||||
ReleaseVersion = decrementVersionIfInvalid(latestVersionTag)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decrementVersionIfInvalid(initialVersion string) string {
|
|
||||||
// Decrements the version up to 5 times if the version doesn't have valid binaries yet.
|
|
||||||
version := initialVersion
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
updateInfo := buildUpdateInfo(version)
|
|
||||||
err := assertValidUpdate(updateInfo)
|
|
||||||
if err == nil {
|
|
||||||
fmt.Printf("Found a valid version: %v\n", version)
|
|
||||||
return version
|
|
||||||
}
|
|
||||||
fmt.Printf("Found %s to be an invalid version: %v\n", version, err)
|
|
||||||
version, err = decrementVersion(version)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to decrement version after finding the latest version was invalid: %v\n", err)
|
|
||||||
return initialVersion
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Printf("Decremented the version 5 times and failed to find a valid version version number, initial version number: %v, last checked version number: %v\n", initialVersion, version)
|
|
||||||
return initialVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertValidUpdate(updateInfo shared.UpdateInfo) error {
|
|
||||||
urls := []string{updateInfo.LinuxAmd64Url, updateInfo.LinuxAmd64AttestationUrl, updateInfo.LinuxArm64Url, updateInfo.LinuxArm64AttestationUrl,
|
|
||||||
updateInfo.LinuxArm7Url, updateInfo.LinuxArm7AttestationUrl,
|
|
||||||
updateInfo.DarwinAmd64Url, updateInfo.DarwinAmd64UnsignedUrl, updateInfo.DarwinAmd64AttestationUrl,
|
|
||||||
updateInfo.DarwinArm64Url, updateInfo.DarwinArm64UnsignedUrl, updateInfo.DarwinArm64AttestationUrl}
|
|
||||||
for _, url := range urls {
|
|
||||||
resp, err := http.Get(url)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to retrieve URL %#v: %w", url, err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode == 404 {
|
|
||||||
return fmt.Errorf("URL %#v returned 404", url)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InitDB() {
|
|
||||||
var err error
|
var err error
|
||||||
GLOBAL_DB, err = OpenDB()
|
GLOBAL_DB, err = OpenDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -639,233 +155,34 @@ func InitDB() {
|
|||||||
panic(fmt.Errorf("failed to set max idle conns: %w", err))
|
panic(fmt.Errorf("failed to set max idle conns: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func decrementVersion(version string) (string, error) {
|
return GLOBAL_DB
|
||||||
if version == "UNKNOWN" {
|
|
||||||
return "", fmt.Errorf("cannot decrement UNKNOWN")
|
|
||||||
}
|
|
||||||
parts := strings.Split(version, ".")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return "", fmt.Errorf("invalid version: %s", version)
|
|
||||||
}
|
|
||||||
versionNumber, err := strconv.Atoi(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("invalid version: %s", version)
|
|
||||||
}
|
|
||||||
return parts[0] + "." + strconv.Itoa(versionNumber-1), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildUpdateInfo(version string) shared.UpdateInfo {
|
|
||||||
return shared.UpdateInfo{
|
|
||||||
LinuxAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64", version),
|
|
||||||
LinuxAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64.intoto.jsonl", version),
|
|
||||||
LinuxArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64", version),
|
|
||||||
LinuxArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64.intoto.jsonl", version),
|
|
||||||
LinuxArm7Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm", version),
|
|
||||||
LinuxArm7AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm.intoto.jsonl", version),
|
|
||||||
DarwinAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64", version),
|
|
||||||
DarwinAmd64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64-unsigned", version),
|
|
||||||
DarwinAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64.intoto.jsonl", version),
|
|
||||||
DarwinArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64", version),
|
|
||||||
DarwinArm64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64-unsigned", version),
|
|
||||||
DarwinArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64.intoto.jsonl", version),
|
|
||||||
Version: version,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func apiDownloadHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
updateInfo := buildUpdateInfo(ReleaseVersion)
|
|
||||||
resp, err := json.Marshal(updateInfo)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
w.Write(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// returns "OK" unless there is a current SLSA bug
|
|
||||||
v := getHishtoryVersion(r)
|
|
||||||
if !strings.Contains(v, "v0.") {
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
vNum, err := strconv.Atoi(strings.Split(v, ".")[1])
|
|
||||||
if err != nil {
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if vNum < 159 {
|
|
||||||
w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func feedbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
data, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var feedback shared.Feedback
|
|
||||||
err = json.Unmarshal(data, &feedback)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
|
|
||||||
}
|
|
||||||
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
|
|
||||||
err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback)
|
|
||||||
checkGormError(err, 0)
|
|
||||||
|
|
||||||
if GLOBAL_STATSD != nil {
|
|
||||||
GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Length", "0")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
type loggedResponseData struct {
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
type loggingResponseWriter struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
responseData *loggedResponseData
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *loggingResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
size, err := r.ResponseWriter.Write(b)
|
|
||||||
r.responseData.size += size
|
|
||||||
return size, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *loggingResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
r.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFunctionName(temp interface{}) string {
|
|
||||||
strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".")
|
|
||||||
return strs[len(strs)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
func withLogging(h http.HandlerFunc) http.Handler {
|
|
||||||
logFn := func(rw http.ResponseWriter, r *http.Request) {
|
|
||||||
var responseData loggedResponseData
|
|
||||||
lrw := loggingResponseWriter{
|
|
||||||
ResponseWriter: rw,
|
|
||||||
responseData: &responseData,
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
span, ctx := tracer.StartSpanFromContext(
|
|
||||||
r.Context(),
|
|
||||||
getFunctionName(h),
|
|
||||||
tracer.SpanType(ext.SpanTypeSQL),
|
|
||||||
tracer.ServiceName("hishtory-api"),
|
|
||||||
)
|
|
||||||
defer span.Finish()
|
|
||||||
|
|
||||||
h(&lrw, r.WithContext(ctx))
|
|
||||||
|
|
||||||
duration := time.Since(start)
|
|
||||||
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
|
|
||||||
if GLOBAL_STATSD != nil {
|
|
||||||
GLOBAL_STATSD.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0)
|
|
||||||
GLOBAL_STATSD.Incr("hishtory.request", []string{}, 1.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return http.HandlerFunc(logFn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func byteCountToString(b int) string {
|
|
||||||
const unit = 1000
|
|
||||||
if b < unit {
|
|
||||||
return fmt.Sprintf("%d B", b)
|
|
||||||
}
|
|
||||||
div, exp := int64(unit), 0
|
|
||||||
for n := b / unit; n >= unit; n /= unit {
|
|
||||||
div *= unit
|
|
||||||
exp++
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
|
|
||||||
}
|
|
||||||
|
|
||||||
func configureObservability(mux *httptrace.ServeMux) func() {
|
|
||||||
// Profiler
|
|
||||||
err := profiler.Start(
|
|
||||||
profiler.WithService("hishtory-api"),
|
|
||||||
profiler.WithVersion(ReleaseVersion),
|
|
||||||
profiler.WithAPIKey(os.Getenv("DD_API_KEY")),
|
|
||||||
profiler.WithUDS("/var/run/datadog/apm.socket"),
|
|
||||||
profiler.WithProfileTypes(
|
|
||||||
profiler.CPUProfile,
|
|
||||||
profiler.HeapProfile,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to start DataDog profiler: %v\n", err)
|
|
||||||
}
|
|
||||||
// Tracer
|
|
||||||
tracer.Start(
|
|
||||||
tracer.WithRuntimeMetrics(),
|
|
||||||
tracer.WithService("hishtory-api"),
|
|
||||||
tracer.WithUDS("/var/run/datadog/apm.socket"),
|
|
||||||
)
|
|
||||||
defer tracer.Stop()
|
|
||||||
// Stats
|
|
||||||
ddStats, err := statsd.New("unix:///var/run/datadog/dsd.socket")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Failed to start DataDog statsd: %v\n", err)
|
|
||||||
}
|
|
||||||
GLOBAL_STATSD = ddStats
|
|
||||||
// Pprof
|
|
||||||
mux.HandleFunc("/debug/pprof/", pprofhttp.Index)
|
|
||||||
mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline)
|
|
||||||
mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile)
|
|
||||||
mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol)
|
|
||||||
mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace)
|
|
||||||
|
|
||||||
// Func to stop all of the above
|
|
||||||
return func() {
|
|
||||||
profiler.Stop()
|
|
||||||
tracer.Stop()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
mux := httptrace.NewServeMux()
|
s, err := statsd.New(StatsdSocket)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to start DataDog statsd: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
if isProductionEnvironment() {
|
// TODO: remove this global once we have a better way to pass it around
|
||||||
defer configureObservability(mux)()
|
GLOBAL_STATSD = s
|
||||||
go func() {
|
|
||||||
if err := GLOBAL_DB.DeepClean(context.Background()); err != nil {
|
srv := server.NewServer(
|
||||||
|
GLOBAL_DB,
|
||||||
|
server.WithStatsd(s),
|
||||||
|
server.WithReleaseVersion(release.Version),
|
||||||
|
server.IsTestEnvironment(isTestEnvironment()),
|
||||||
|
server.IsProductionEnvironment(isProductionEnvironment()),
|
||||||
|
server.WithCron(cron),
|
||||||
|
server.WithUpdateInfo(release.BuildUpdateInfo(release.Version)),
|
||||||
|
)
|
||||||
|
|
||||||
|
go runBackgroundJobs(context.Background(), srv)
|
||||||
|
|
||||||
|
if err := srv.Run(context.Background(), ":8080"); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler))
|
|
||||||
mux.Handle("/api/v1/get-dump-requests", withLogging(apiGetPendingDumpRequestsHandler))
|
|
||||||
mux.Handle("/api/v1/submit-dump", withLogging(apiSubmitDumpHandler))
|
|
||||||
mux.Handle("/api/v1/query", withLogging(apiQueryHandler))
|
|
||||||
mux.Handle("/api/v1/bootstrap", withLogging(apiBootstrapHandler))
|
|
||||||
mux.Handle("/api/v1/register", withLogging(apiRegisterHandler))
|
|
||||||
mux.Handle("/api/v1/banner", withLogging(apiBannerHandler))
|
|
||||||
mux.Handle("/api/v1/download", withLogging(apiDownloadHandler))
|
|
||||||
mux.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler))
|
|
||||||
mux.Handle("/api/v1/get-deletion-requests", withLogging(getDeletionRequestsHandler))
|
|
||||||
mux.Handle("/api/v1/add-deletion-request", withLogging(addDeletionRequestHandler))
|
|
||||||
mux.Handle("/api/v1/slsa-status", withLogging(slsaStatusHandler))
|
|
||||||
mux.Handle("/api/v1/feedback", withLogging(feedbackHandler))
|
|
||||||
mux.Handle("/healthcheck", withLogging(healthCheckHandler))
|
|
||||||
mux.Handle("/internal/api/v1/usage-stats", withLogging(usageStatsHandler))
|
|
||||||
mux.Handle("/internal/api/v1/stats", withLogging(statsHandler))
|
|
||||||
if isTestEnvironment() {
|
|
||||||
mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler))
|
|
||||||
mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler))
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Listening on localhost:8080")
|
|
||||||
log.Fatal(http.ListenAndServe(":8080", mux))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkGormResult(result *gorm.DB) {
|
func checkGormResult(result *gorm.DB) {
|
||||||
@ -881,17 +198,5 @@ func checkGormError(err error, skip int) {
|
|||||||
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err))
|
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMaximumNumberOfAllowedUsers() int {
|
|
||||||
maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS")
|
|
||||||
if maxNumUsersStr == "" {
|
|
||||||
return math.MaxInt
|
|
||||||
}
|
|
||||||
maxNumUsers, err := strconv.Atoi(maxNumUsersStr)
|
|
||||||
if err != nil {
|
|
||||||
return math.MaxInt
|
|
||||||
}
|
|
||||||
return maxNumUsers
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?
|
// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?
|
||||||
// TODO: Add error checking for the calls to updateUsageData(...) that logs it/triggers an alert in prod, but is an error in test
|
// TODO: Add error checking for the calls to updateUsageData(...) that logs it/triggers an alert in prod, but is an error in test
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/ddworken/hishtory/shared"
|
"github.com/ddworken/hishtory/shared"
|
||||||
"github.com/jackc/pgx/v4/stdlib"
|
"github.com/jackc/pgx/v4/stdlib"
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
127
internal/release/release.go
Normal file
127
internal/release/release.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package release
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/ddworken/hishtory/shared"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Version = "UNKNOWN"
|
||||||
|
|
||||||
|
type releaseInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const releaseURL = "https://api.github.com/repos/ddworken/hishtory/releases/latest"
|
||||||
|
|
||||||
|
func UpdateReleaseVersion() error {
|
||||||
|
resp, err := http.Get(releaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get latest release version: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read github API response body: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode == 403 && strings.Contains(string(respBody), "API rate limit exceeded for ") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return fmt.Errorf("failed to call github API, status_code=%d, body=%#v", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
var info releaseInfo
|
||||||
|
err = json.Unmarshal(respBody, &info)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse github API response: %w", err)
|
||||||
|
}
|
||||||
|
latestVersionTag := info.Name
|
||||||
|
Version = decrementVersionIfInvalid(latestVersionTag)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildUpdateInfo(version string) shared.UpdateInfo {
|
||||||
|
return shared.UpdateInfo{
|
||||||
|
LinuxAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64", version),
|
||||||
|
LinuxAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64.intoto.jsonl", version),
|
||||||
|
LinuxArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64", version),
|
||||||
|
LinuxArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64.intoto.jsonl", version),
|
||||||
|
LinuxArm7Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm", version),
|
||||||
|
LinuxArm7AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm.intoto.jsonl", version),
|
||||||
|
DarwinAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64", version),
|
||||||
|
DarwinAmd64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64-unsigned", version),
|
||||||
|
DarwinAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64.intoto.jsonl", version),
|
||||||
|
DarwinArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64", version),
|
||||||
|
DarwinArm64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64-unsigned", version),
|
||||||
|
DarwinArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64.intoto.jsonl", version),
|
||||||
|
Version: version,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decrementVersionIfInvalid(initialVersion string) string {
|
||||||
|
// Decrements the version up to 5 times if the version doesn't have valid binaries yet.
|
||||||
|
version := initialVersion
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
updateInfo := BuildUpdateInfo(version)
|
||||||
|
err := assertValidUpdate(updateInfo)
|
||||||
|
if err == nil {
|
||||||
|
fmt.Printf("Found a valid version: %v\n", version)
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
fmt.Printf("Found %s to be an invalid version: %v\n", version, err)
|
||||||
|
version, err = decrementVersion(version)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to decrement version after finding the latest version was invalid: %v\n", err)
|
||||||
|
return initialVersion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf("Decremented the version 5 times and failed to find a valid version version number, initial version number: %v, last checked version number: %v\n", initialVersion, version)
|
||||||
|
return initialVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertValidUpdate(updateInfo shared.UpdateInfo) error {
|
||||||
|
urls := []string{
|
||||||
|
updateInfo.LinuxAmd64Url,
|
||||||
|
updateInfo.LinuxAmd64AttestationUrl,
|
||||||
|
updateInfo.LinuxArm64Url,
|
||||||
|
updateInfo.LinuxArm64AttestationUrl,
|
||||||
|
updateInfo.LinuxArm7Url,
|
||||||
|
updateInfo.LinuxArm7AttestationUrl,
|
||||||
|
updateInfo.DarwinAmd64Url,
|
||||||
|
updateInfo.DarwinAmd64UnsignedUrl,
|
||||||
|
updateInfo.DarwinAmd64AttestationUrl,
|
||||||
|
updateInfo.DarwinArm64Url,
|
||||||
|
updateInfo.DarwinArm64UnsignedUrl,
|
||||||
|
updateInfo.DarwinArm64AttestationUrl,
|
||||||
|
}
|
||||||
|
for _, url := range urls {
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to retrieve URL %#v: %w", url, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
return fmt.Errorf("URL %#v returned 404", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decrementVersion(version string) (string, error) {
|
||||||
|
if version == "UNKNOWN" {
|
||||||
|
return "", fmt.Errorf("cannot decrement UNKNOWN")
|
||||||
|
}
|
||||||
|
parts := strings.Split(version, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", fmt.Errorf("invalid version: %s", version)
|
||||||
|
}
|
||||||
|
versionNumber, err := strconv.Atoi(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("invalid version: %s", version)
|
||||||
|
}
|
||||||
|
return parts[0] + "." + strconv.Itoa(versionNumber-1), nil
|
||||||
|
}
|
33
internal/release/release_test.go
Normal file
33
internal/release/release_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package release
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ddworken/hishtory/shared/testutils"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateReleaseVersion(t *testing.T) {
|
||||||
|
if !testutils.IsOnline() {
|
||||||
|
t.Skip("skipping because we're currently offline")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that ReleaseVersion hasn't been set yet
|
||||||
|
if Version != "UNKNOWN" {
|
||||||
|
t.Fatalf("initial ReleaseVersion isn't as expected: %#v", Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update it
|
||||||
|
err := UpdateReleaseVersion()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("updateReleaseVersion failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If ReleaseVersion is still unknown, skip because we're getting rate limited
|
||||||
|
if Version == "UNKNOWN" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
// Otherwise, check that the new value looks reasonable
|
||||||
|
if !strings.HasPrefix(Version, "v0.") {
|
||||||
|
t.Fatalf("ReleaseVersion wasn't updated to contain a version: %#v", Version)
|
||||||
|
}
|
||||||
|
}
|
239
internal/server/api.go
Normal file
239
internal/server/api.go
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ddworken/hishtory/shared"
|
||||||
|
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var entries []*shared.EncHistoryEntry
|
||||||
|
err = json.Unmarshal(data, &entries)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
||||||
|
}
|
||||||
|
fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries))
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add these to the context in a middleware
|
||||||
|
version := getHishtoryVersion(r)
|
||||||
|
remoteIPAddr := getRemoteAddr(r)
|
||||||
|
|
||||||
|
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil {
|
||||||
|
fmt.Printf("updateUsageData: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
if len(devices) == 0 {
|
||||||
|
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
|
||||||
|
}
|
||||||
|
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
|
||||||
|
|
||||||
|
err = s.db.AddHistoryEntriesForAllDevices(r.Context(), devices, entries)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
|
||||||
|
}
|
||||||
|
if s.statsd != nil {
|
||||||
|
s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
|
|
||||||
|
// TODO: add these to the context in a middleware
|
||||||
|
version := getHishtoryVersion(r)
|
||||||
|
remoteIPAddr := getRemoteAddr(r)
|
||||||
|
|
||||||
|
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil {
|
||||||
|
fmt.Printf("updateUsageData: %v\n", err)
|
||||||
|
}
|
||||||
|
historyEntries, err := s.db.AllHistoryEntriesForUser(r.Context(), userId)
|
||||||
|
checkGormError(err)
|
||||||
|
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
|
||||||
|
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiQueryHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
|
|
||||||
|
// TODO: add these to the context in a middleware
|
||||||
|
version := getHishtoryVersion(r)
|
||||||
|
remoteIPAddr := getRemoteAddr(r)
|
||||||
|
|
||||||
|
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil {
|
||||||
|
fmt.Printf("updateUsageData: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete any entries that match a pending deletion request
|
||||||
|
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
||||||
|
checkGormError(err)
|
||||||
|
for _, request := range deletionRequests {
|
||||||
|
_, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request)
|
||||||
|
checkGormError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then retrieve
|
||||||
|
historyEntries, err := s.db.HistoryEntriesForDevice(r.Context(), deviceId, 5)
|
||||||
|
checkGormError(err)
|
||||||
|
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
|
||||||
|
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids
|
||||||
|
// blocking the entire response. This does have a potential race condition, but that is fine.
|
||||||
|
if s.isProductionEnvironment {
|
||||||
|
go func() {
|
||||||
|
span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount")
|
||||||
|
err := s.db.IncrementEntryReadCountsForDevice(ctx, deviceId)
|
||||||
|
span.Finish(tracer.WithError(err))
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
err := s.db.IncrementEntryReadCountsForDevice(ctx, deviceId)
|
||||||
|
if err != nil {
|
||||||
|
panic("failed to increment read counts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.statsd != nil {
|
||||||
|
s.statsd.Incr("hishtory.query", []string{}, 1.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
|
srcDeviceId := getRequiredQueryParam(r, "source_device_id")
|
||||||
|
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
|
||||||
|
data, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var entries []*shared.EncHistoryEntry
|
||||||
|
err = json.Unmarshal(data, &entries)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
||||||
|
}
|
||||||
|
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
|
||||||
|
|
||||||
|
// sanity check
|
||||||
|
for _, entry := range entries {
|
||||||
|
entry.DeviceId = requestingDeviceId
|
||||||
|
if entry.UserId != userId {
|
||||||
|
panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.db.AddHistoryEntries(r.Context(), entries...)
|
||||||
|
checkGormError(err)
|
||||||
|
err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
// TODO: add these to the context in a middleware
|
||||||
|
version := getHishtoryVersion(r)
|
||||||
|
remoteIPAddr := getRemoteAddr(r)
|
||||||
|
|
||||||
|
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil {
|
||||||
|
fmt.Printf("updateUsageData: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiBannerHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
commitHash := getRequiredQueryParam(r, "commit_hash")
|
||||||
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
|
forcedBanner := r.URL.Query().Get("forced_banner")
|
||||||
|
fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner)
|
||||||
|
if getHishtoryVersion(r) == "v0.160" {
|
||||||
|
w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory."))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write([]byte(html.EscapeString(forcedBanner)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
|
var dumpRequests []*shared.DumpRequest
|
||||||
|
// Filter out ones requested by the hishtory instance that sent this request
|
||||||
|
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(dumpRequests); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiDownloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := json.NewEncoder(w).Encode(s.updateInfo)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to JSON marshall the update info: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if getMaximumNumberOfAllowedUsers() < math.MaxInt {
|
||||||
|
numDistinctUsers, err := s.db.DistinctUsers(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("db.DistinctUsers: %w", err))
|
||||||
|
}
|
||||||
|
if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) {
|
||||||
|
panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
|
|
||||||
|
existingDevicesCount, err := s.db.CountDevicesForUser(r.Context(), userId)
|
||||||
|
checkGormError(err)
|
||||||
|
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
|
||||||
|
if err := s.db.CreateDevice(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil {
|
||||||
|
checkGormError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if existingDevicesCount > 0 {
|
||||||
|
err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})
|
||||||
|
checkGormError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add these to the context in a middleware
|
||||||
|
version := getHishtoryVersion(r)
|
||||||
|
remoteIPAddr := getRemoteAddr(r)
|
||||||
|
|
||||||
|
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil {
|
||||||
|
fmt.Printf("updateUsageData: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.statsd != nil {
|
||||||
|
s.statsd.Incr("hishtory.register", []string{}, 1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
82
internal/server/middleware.go
Normal file
82
internal/server/middleware.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
|
||||||
|
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type loggedResponseData struct {
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
type loggingResponseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
responseData *loggedResponseData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *loggingResponseWriter) Write(b []byte) (int, error) {
|
||||||
|
size, err := r.ResponseWriter.Write(b)
|
||||||
|
r.responseData.size += size
|
||||||
|
return size, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *loggingResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
r.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFunctionName(temp interface{}) string {
|
||||||
|
strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".")
|
||||||
|
return strs[len(strs)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func byteCountToString(b int) string {
|
||||||
|
const unit = 1000
|
||||||
|
if b < unit {
|
||||||
|
return fmt.Sprintf("%d B", b)
|
||||||
|
}
|
||||||
|
div, exp := int64(unit), 0
|
||||||
|
for n := b / unit; n >= unit; n /= unit {
|
||||||
|
div *= unit
|
||||||
|
exp++
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
|
||||||
|
}
|
||||||
|
|
||||||
|
type Middleware func(http.HandlerFunc) http.Handler
|
||||||
|
|
||||||
|
func withLogging(s *statsd.Client) Middleware {
|
||||||
|
return func(h http.HandlerFunc) http.Handler {
|
||||||
|
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
var responseData loggedResponseData
|
||||||
|
lrw := loggingResponseWriter{
|
||||||
|
ResponseWriter: rw,
|
||||||
|
responseData: &responseData,
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
span, ctx := tracer.StartSpanFromContext(
|
||||||
|
r.Context(),
|
||||||
|
getFunctionName(h),
|
||||||
|
tracer.SpanType(ext.SpanTypeSQL),
|
||||||
|
tracer.ServiceName("hishtory-api"),
|
||||||
|
)
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
h.ServeHTTP(&lrw, r.WithContext(ctx))
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
|
||||||
|
if s != nil {
|
||||||
|
s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0)
|
||||||
|
s.Incr("hishtory.request", []string{}, 1.0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,11 +1,13 @@
|
|||||||
package main
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"github.com/ddworken/hishtory/internal/database"
|
"github.com/ddworken/hishtory/internal/database"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -21,9 +23,32 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var DB *database.DB
|
||||||
|
|
||||||
|
const testDBDSN = "file::memory:?_journal_mode=WAL&cache=shared"
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
// setup test database
|
||||||
|
db, err := database.OpenSQLite(testDBDSN, &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to connect to the DB: %w", err))
|
||||||
|
}
|
||||||
|
underlyingDb, err := db.DB.DB()
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to access underlying DB: %w", err))
|
||||||
|
}
|
||||||
|
underlyingDb.SetMaxOpenConns(1)
|
||||||
|
db.Exec("PRAGMA journal_mode = WAL")
|
||||||
|
db.AddDatabaseTables()
|
||||||
|
|
||||||
|
DB = db
|
||||||
|
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
func TestESubmitThenQuery(t *testing.T) {
|
func TestESubmitThenQuery(t *testing.T) {
|
||||||
// Set up
|
// Set up
|
||||||
InitDB()
|
s := NewServer(DB)
|
||||||
|
|
||||||
// Register a few devices
|
// Register a few devices
|
||||||
userId := data.UserId("key")
|
userId := data.UserId("key")
|
||||||
@ -32,11 +57,11 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
otherUser := data.UserId("otherkey")
|
otherUser := data.UserId("otherkey")
|
||||||
otherDev := uuid.Must(uuid.NewRandom()).String()
|
otherDev := uuid.Must(uuid.NewRandom()).String()
|
||||||
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
|
|
||||||
// Submit a few entries for different devices
|
// Submit a few entries for different devices
|
||||||
entry := testutils.MakeFakeHistoryEntry("ls ~/")
|
entry := testutils.MakeFakeHistoryEntry("ls ~/")
|
||||||
@ -45,12 +70,12 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
||||||
apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
||||||
|
|
||||||
// Query for device id 1
|
// Query for device id 1
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err := io.ReadAll(res.Body)
|
respBody, err := io.ReadAll(res.Body)
|
||||||
@ -79,7 +104,7 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
// Same for device id 2
|
// Same for device id 2
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
||||||
apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -107,7 +132,7 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
// Bootstrap handler should return 2 entries, one for each device
|
// Bootstrap handler should return 2 entries, one for each device
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil)
|
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil)
|
||||||
apiBootstrapHandler(w, searchReq)
|
s.apiBootstrapHandler(w, searchReq)
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -118,12 +143,12 @@ func TestESubmitThenQuery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
// Assert that we aren't leaking connections
|
||||||
assertNoLeakedConnections(t, GLOBAL_DB)
|
assertNoLeakedConnections(t, DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDumpRequestAndResponse(t *testing.T) {
|
func TestDumpRequestAndResponse(t *testing.T) {
|
||||||
// Set up
|
// Set up
|
||||||
InitDB()
|
s := NewServer(DB)
|
||||||
|
|
||||||
// Register a first device for two different users
|
// Register a first device for two different users
|
||||||
userId := data.UserId("dkey")
|
userId := data.UserId("dkey")
|
||||||
@ -133,17 +158,17 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
otherDev1 := uuid.Must(uuid.NewRandom()).String()
|
otherDev1 := uuid.Must(uuid.NewRandom()).String()
|
||||||
otherDev2 := uuid.Must(uuid.NewRandom()).String()
|
otherDev2 := uuid.Must(uuid.NewRandom()).String()
|
||||||
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
|
|
||||||
// Query for dump requests, there should be one for userId
|
// Query for dump requests, there should be one for userId
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err := io.ReadAll(res.Body)
|
respBody, err := io.ReadAll(res.Body)
|
||||||
@ -163,7 +188,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
|
|
||||||
// And one for otherUser
|
// And one for otherUser
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -183,7 +208,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
|
|
||||||
// And none if we query for a user ID that doesn't exit
|
// And none if we query for a user ID that doesn't exit
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil))
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -193,7 +218,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
|
|
||||||
// And none for a missing user ID
|
// And none for a missing user ID
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil))
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -211,11 +236,11 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2})
|
reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2})
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody))
|
submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody))
|
||||||
apiSubmitDumpHandler(httptest.NewRecorder(), submitReq)
|
s.apiSubmitDumpHandler(httptest.NewRecorder(), submitReq)
|
||||||
|
|
||||||
// Check that the dump request is no longer there for userId for either device ID
|
// Check that the dump request is no longer there for userId for either device ID
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -226,7 +251,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
|
|
||||||
// The other user
|
// The other user
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil))
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -236,7 +261,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
|
|
||||||
// But it is there for the other user
|
// But it is there for the other user
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
|
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -257,7 +282,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
// And finally, query to ensure that the dumped entries are in the DB
|
// And finally, query to ensure that the dumped entries are in the DB
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
||||||
apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -285,44 +310,12 @@ func TestDumpRequestAndResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
// Assert that we aren't leaking connections
|
||||||
assertNoLeakedConnections(t, GLOBAL_DB)
|
assertNoLeakedConnections(t, DB)
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateReleaseVersion(t *testing.T) {
|
|
||||||
if !testutils.IsOnline() {
|
|
||||||
t.Skip("skipping because we're currently offline")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up
|
|
||||||
InitDB()
|
|
||||||
|
|
||||||
// Check that ReleaseVersion hasn't been set yet
|
|
||||||
if ReleaseVersion != "UNKNOWN" {
|
|
||||||
t.Fatalf("initial ReleaseVersion isn't as expected: %#v", ReleaseVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update it
|
|
||||||
err := updateReleaseVersion()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("updateReleaseVersion failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If ReleaseVersion is still unknown, skip because we're getting rate limited
|
|
||||||
if ReleaseVersion == "UNKNOWN" {
|
|
||||||
t.Skip()
|
|
||||||
}
|
|
||||||
// Otherwise, check that the new value looks reasonable
|
|
||||||
if !strings.HasPrefix(ReleaseVersion, "v0.") {
|
|
||||||
t.Fatalf("ReleaseVersion wasn't updated to contain a version: %#v", ReleaseVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
|
||||||
assertNoLeakedConnections(t, GLOBAL_DB)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeletionRequests(t *testing.T) {
|
func TestDeletionRequests(t *testing.T) {
|
||||||
// Set up
|
// Set up
|
||||||
InitDB()
|
s := NewServer(DB)
|
||||||
|
|
||||||
// Register two devices for two different users
|
// Register two devices for two different users
|
||||||
userId := data.UserId("dkey")
|
userId := data.UserId("dkey")
|
||||||
@ -332,13 +325,13 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
otherDev1 := uuid.Must(uuid.NewRandom()).String()
|
otherDev1 := uuid.Must(uuid.NewRandom()).String()
|
||||||
otherDev2 := uuid.Must(uuid.NewRandom()).String()
|
otherDev2 := uuid.Must(uuid.NewRandom()).String()
|
||||||
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
|
|
||||||
// Add an entry for user1
|
// Add an entry for user1
|
||||||
entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
|
entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
|
||||||
@ -348,7 +341,7 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
||||||
apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
||||||
|
|
||||||
// And another entry for user1
|
// And another entry for user1
|
||||||
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
||||||
@ -358,7 +351,7 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
|
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
||||||
apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
||||||
|
|
||||||
// And an entry for user2 that has the same timestamp as the previous entry
|
// And an entry for user2 that has the same timestamp as the previous entry
|
||||||
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
|
||||||
@ -369,12 +362,12 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
|
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
||||||
apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
||||||
|
|
||||||
// Query for device id 1
|
// Query for device id 1
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err := io.ReadAll(res.Body)
|
respBody, err := io.ReadAll(res.Body)
|
||||||
@ -413,13 +406,13 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
reqBody, err = json.Marshal(delReq)
|
reqBody, err = json.Marshal(delReq)
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
||||||
addDeletionRequestHandler(httptest.NewRecorder(), req)
|
s.addDeletionRequestHandler(httptest.NewRecorder(), req)
|
||||||
|
|
||||||
// Query again for device id 1 and get a single result
|
// Query again for device id 1 and get a single result
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -447,7 +440,7 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
// Query for user 2
|
// Query for user 2
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
|
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
|
||||||
apiQueryHandler(w, searchReq)
|
s.apiQueryHandler(w, searchReq)
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -475,7 +468,7 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
// Query for deletion requests
|
// Query for deletion requests
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
||||||
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
getDeletionRequestsHandler(w, searchReq)
|
s.getDeletionRequestsHandler(w, searchReq)
|
||||||
res = w.Result()
|
res = w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
respBody, err = io.ReadAll(res.Body)
|
respBody, err = io.ReadAll(res.Body)
|
||||||
@ -500,12 +493,13 @@ func TestDeletionRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
// Assert that we aren't leaking connections
|
||||||
assertNoLeakedConnections(t, GLOBAL_DB)
|
assertNoLeakedConnections(t, DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHealthcheck(t *testing.T) {
|
func TestHealthcheck(t *testing.T) {
|
||||||
|
s := NewServer(DB)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
if w.Code != 200 {
|
if w.Code != 200 {
|
||||||
t.Fatalf("expected 200 resp code for healthCheckHandler")
|
t.Fatalf("expected 200 resp code for healthCheckHandler")
|
||||||
}
|
}
|
||||||
@ -518,41 +512,47 @@ func TestHealthcheck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assert that we aren't leaking connections
|
// Assert that we aren't leaking connections
|
||||||
assertNoLeakedConnections(t, GLOBAL_DB)
|
assertNoLeakedConnections(t, DB)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLimitRegistrations(t *testing.T) {
|
func TestLimitRegistrations(t *testing.T) {
|
||||||
// Set up
|
// Set up
|
||||||
InitDB()
|
s := NewServer(DB)
|
||||||
checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries"))
|
|
||||||
checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices"))
|
if resp := DB.Exec("DELETE FROM enc_history_entries"); resp.Error != nil {
|
||||||
|
t.Fatalf("failed to delete enc_history_entries: %v", resp.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp := DB.Exec("DELETE FROM devices"); resp.Error != nil {
|
||||||
|
t.Fatalf("failed to delete devices: %v", resp.Error)
|
||||||
|
}
|
||||||
defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")()
|
defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")()
|
||||||
os.Setenv("HISHTORY_MAX_NUM_USERS", "2")
|
os.Setenv("HISHTORY_MAX_NUM_USERS", "2")
|
||||||
|
|
||||||
// Register three devices across two users
|
// Register three devices across two users
|
||||||
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil)
|
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
|
|
||||||
// And this next one should fail since it is a new user
|
// And this next one should fail since it is a new user
|
||||||
defer func() { _ = recover() }()
|
defer func() { _ = recover() }()
|
||||||
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil)
|
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
t.Errorf("expected panic")
|
t.Errorf("expected panic")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCleanDatabaseNoErrors(t *testing.T) {
|
func TestCleanDatabaseNoErrors(t *testing.T) {
|
||||||
// Init
|
// Init
|
||||||
InitDB()
|
s := NewServer(DB)
|
||||||
|
|
||||||
// Create a user and an entry
|
// Create a user and an entry
|
||||||
userId := data.UserId("dkey")
|
userId := data.UserId("dkey")
|
||||||
devId1 := uuid.Must(uuid.NewRandom()).String()
|
devId1 := uuid.Must(uuid.NewRandom()).String()
|
||||||
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
|
||||||
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
|
||||||
entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
|
entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
|
||||||
entry1.DeviceId = devId1
|
entry1.DeviceId = devId1
|
||||||
encEntry, err := data.EncryptHistoryEntry("dkey", entry1)
|
encEntry, err := data.EncryptHistoryEntry("dkey", entry1)
|
||||||
@ -560,10 +560,10 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
|
|||||||
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
|
||||||
testutils.Check(t, err)
|
testutils.Check(t, err)
|
||||||
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
|
||||||
apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
|
||||||
|
|
||||||
// Call cleanDatabase and just check that there are no panics
|
// Call cleanDatabase and just check that there are no panics
|
||||||
testutils.Check(t, GLOBAL_DB.Clean(context.TODO()))
|
testutils.Check(t, DB.Clean(context.TODO()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertNoLeakedConnections(t *testing.T, db *database.DB) {
|
func assertNoLeakedConnections(t *testing.T, db *database.DB) {
|
374
internal/server/srv.go
Normal file
374
internal/server/srv.go
Normal file
@ -0,0 +1,374 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DataDog/datadog-go/statsd"
|
||||||
|
"github.com/ddworken/hishtory/internal/database"
|
||||||
|
"github.com/ddworken/hishtory/shared"
|
||||||
|
"github.com/rodaine/table"
|
||||||
|
httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
db *database.DB
|
||||||
|
statsd *statsd.Client
|
||||||
|
|
||||||
|
isProductionEnvironment bool
|
||||||
|
isTestEnvironment bool
|
||||||
|
releaseVersion string
|
||||||
|
cronFn CronFn
|
||||||
|
updateInfo shared.UpdateInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type CronFn func(ctx context.Context, db *database.DB, stats *statsd.Client) error
|
||||||
|
type Option func(*Server)
|
||||||
|
|
||||||
|
func WithStatsd(statsd *statsd.Client) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.statsd = statsd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithReleaseVersion(releaseVersion string) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.releaseVersion = releaseVersion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCron(cronFn CronFn) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.cronFn = cronFn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithUpdateInfo(updateInfo shared.UpdateInfo) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.updateInfo = updateInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsProductionEnvironment(v bool) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.isProductionEnvironment = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsTestEnvironment(v bool) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.isTestEnvironment = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(db *database.DB, options ...Option) *Server {
|
||||||
|
srv := Server{db: db}
|
||||||
|
for _, option := range options {
|
||||||
|
option(&srv)
|
||||||
|
}
|
||||||
|
return &srv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Run(ctx context.Context, addr string) error {
|
||||||
|
mux := httptrace.NewServeMux()
|
||||||
|
|
||||||
|
if s.isProductionEnvironment {
|
||||||
|
defer configureObservability(mux, s.releaseVersion)()
|
||||||
|
go func() {
|
||||||
|
if err := s.db.DeepClean(ctx); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
loggerMiddleware := withLogging(s.statsd)
|
||||||
|
|
||||||
|
mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler))
|
||||||
|
mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler))
|
||||||
|
mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler))
|
||||||
|
mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler))
|
||||||
|
mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler))
|
||||||
|
mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler))
|
||||||
|
mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler))
|
||||||
|
mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler))
|
||||||
|
mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler))
|
||||||
|
mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler))
|
||||||
|
mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler))
|
||||||
|
mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler))
|
||||||
|
mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler))
|
||||||
|
mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler))
|
||||||
|
mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler))
|
||||||
|
mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler))
|
||||||
|
if s.isTestEnvironment {
|
||||||
|
mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler))
|
||||||
|
mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler))
|
||||||
|
}
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: mux,
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Listening on %s\n", addr)
|
||||||
|
if err := httpServer.ListenAndServe(); err != nil {
|
||||||
|
if !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return fmt.Errorf("http.ListenAndServe: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) UpdateReleaseVersion(v string, updateInfo shared.UpdateInfo) {
|
||||||
|
s.releaseVersion = v
|
||||||
|
s.updateInfo = updateInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userId := getRequiredQueryParam(r, "user_id")
|
||||||
|
deviceId := getRequiredQueryParam(r, "device_id")
|
||||||
|
|
||||||
|
// Increment the ReadCount
|
||||||
|
err := s.db.DeletionRequestInc(r.Context(), userId, deviceId)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
// Return all the deletion requests
|
||||||
|
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
|
||||||
|
checkGormError(err)
|
||||||
|
if err := json.NewEncoder(w).Encode(deletionRequests); err != nil {
|
||||||
|
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var request shared.DeletionRequest
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &request); err != nil {
|
||||||
|
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
|
||||||
|
}
|
||||||
|
request.ReadCount = 0
|
||||||
|
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
|
||||||
|
|
||||||
|
err = s.db.DeletionRequestCreate(r.Context(), &request)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) triggerCronHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
err := s.cronFn(r.Context(), s.db, s.statsd)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// returns "OK" unless there is a current SLSA bug
|
||||||
|
v := getHishtoryVersion(r)
|
||||||
|
if !strings.Contains(v, "v0.") {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
vNum, err := strconv.Atoi(strings.Split(v, ".")[1])
|
||||||
|
if err != nil {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if vNum < 159 {
|
||||||
|
w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) feedbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
data, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var feedback shared.Feedback
|
||||||
|
err = json.Unmarshal(data, &feedback)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
|
||||||
|
}
|
||||||
|
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
|
||||||
|
err = s.db.FeedbackCreate(r.Context(), &feedback)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
if s.statsd != nil {
|
||||||
|
s.statsd.Incr("hishtory.uninstall", []string{}, 1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) healthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if s.isProductionEnvironment {
|
||||||
|
encHistoryEntryCount, err := s.db.CountHistoryEntries(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
if encHistoryEntryCount < 1000 {
|
||||||
|
panic("Suspiciously few enc history entries!")
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCount, err := s.db.CountAllDevices(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
if deviceCount < 100 {
|
||||||
|
panic("Suspiciously few devices!")
|
||||||
|
}
|
||||||
|
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
|
||||||
|
err = s.db.AddHistoryEntries(r.Context(), &shared.EncHistoryEntry{
|
||||||
|
EncryptedData: []byte("data"),
|
||||||
|
Nonce: []byte("nonce"),
|
||||||
|
DeviceId: "healthcheck_device_id",
|
||||||
|
UserId: "healthcheck_user_id",
|
||||||
|
Date: time.Now(),
|
||||||
|
EncryptedId: "healthcheck_enc_id",
|
||||||
|
ReadCount: 10000,
|
||||||
|
})
|
||||||
|
checkGormError(err)
|
||||||
|
} else {
|
||||||
|
err := s.db.Ping()
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("failed to ping DB: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) usageStatsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
usageData, err := s.db.UsageDataStats(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("db.UsageDataStats: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs")
|
||||||
|
tbl.WithWriter(w)
|
||||||
|
for _, data := range usageData {
|
||||||
|
versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "")
|
||||||
|
lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "")
|
||||||
|
tbl.AddRow(
|
||||||
|
data.RegistrationDate.Format(shared.DateOnly),
|
||||||
|
data.NumDevices,
|
||||||
|
data.NumEntries,
|
||||||
|
data.NumQueries,
|
||||||
|
data.LastUsedDate.Format(shared.DateOnly),
|
||||||
|
lastQueryStr,
|
||||||
|
versions,
|
||||||
|
data.IpAddresses,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
tbl.Print()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) statsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
numDevices, err := s.db.CountAllDevices(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
numEntriesProcessed, err := s.db.UsageDataTotal(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
numDbEntries, err := s.db.CountHistoryEntries(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
oneWeek := time.Hour * 24 * 7
|
||||||
|
weeklyActiveInstalls, err := s.db.CountActiveInstalls(r.Context(), oneWeek)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
weeklyQueryUsers, err := s.db.CountQueryUsers(r.Context(), oneWeek)
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
lastRegistration, err := s.db.DateOfLastRegistration(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices)
|
||||||
|
_, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed)
|
||||||
|
_, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries)
|
||||||
|
_, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls)
|
||||||
|
_, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers)
|
||||||
|
_, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Host == "api.hishtory.dev" || s.isProductionEnvironment {
|
||||||
|
panic("refusing to wipe the DB for prod")
|
||||||
|
}
|
||||||
|
if !s.isTestEnvironment {
|
||||||
|
panic("refusing to wipe the DB non-test environment")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.db.Unsafe_DeleteAllHistoryEntries(r.Context())
|
||||||
|
checkGormError(err)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Length", "0")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
stats, err := s.db.Stats()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error {
|
||||||
|
var usageData []shared.UsageData
|
||||||
|
usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err)
|
||||||
|
}
|
||||||
|
if len(usageData) == 0 {
|
||||||
|
err := s.db.CreateUsageData(
|
||||||
|
ctx,
|
||||||
|
&shared.UsageData{
|
||||||
|
UserId: userId,
|
||||||
|
DeviceId: deviceId,
|
||||||
|
LastUsed: time.Now(),
|
||||||
|
NumEntriesHandled: numEntriesHandled,
|
||||||
|
Version: version,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db.UsageDataCreate: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
usage := usageData[0]
|
||||||
|
|
||||||
|
if err := s.db.UpdateUsageData(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil {
|
||||||
|
return fmt.Errorf("db.UsageDataUpdate: %w", err)
|
||||||
|
}
|
||||||
|
if numEntriesHandled > 0 {
|
||||||
|
if err := s.db.UpdateUsageDataForNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil {
|
||||||
|
return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if usage.Version != version {
|
||||||
|
if err := s.db.UpdateUsageDataClientVersion(ctx, userId, deviceId, version); err != nil {
|
||||||
|
return fmt.Errorf("db.UsageDataUpdateVersion: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isQuery {
|
||||||
|
if err := s.db.UpdateUsageDataNumberQueries(ctx, userId, deviceId); err != nil {
|
||||||
|
return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
93
internal/server/util.go
Normal file
93
internal/server/util.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
|
||||||
|
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||||
|
"gopkg.in/DataDog/dd-trace-go.v1/profiler"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
pprofhttp "net/http/pprof"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getMaximumNumberOfAllowedUsers() int {
|
||||||
|
maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS")
|
||||||
|
if maxNumUsersStr == "" {
|
||||||
|
return math.MaxInt
|
||||||
|
}
|
||||||
|
maxNumUsers, err := strconv.Atoi(maxNumUsersStr)
|
||||||
|
if err != nil {
|
||||||
|
return math.MaxInt
|
||||||
|
}
|
||||||
|
return maxNumUsers
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureObservability(mux *httptrace.ServeMux, releaseVersion string) func() {
|
||||||
|
// Profiler
|
||||||
|
err := profiler.Start(
|
||||||
|
profiler.WithService("hishtory-api"),
|
||||||
|
profiler.WithVersion(releaseVersion),
|
||||||
|
profiler.WithAPIKey(os.Getenv("DD_API_KEY")),
|
||||||
|
profiler.WithUDS("/var/run/datadog/apm.socket"),
|
||||||
|
profiler.WithProfileTypes(
|
||||||
|
profiler.CPUProfile,
|
||||||
|
profiler.HeapProfile,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to start DataDog profiler: %v\n", err)
|
||||||
|
}
|
||||||
|
// Tracer
|
||||||
|
tracer.Start(
|
||||||
|
tracer.WithRuntimeMetrics(),
|
||||||
|
tracer.WithService("hishtory-api"),
|
||||||
|
tracer.WithUDS("/var/run/datadog/apm.socket"),
|
||||||
|
)
|
||||||
|
// TODO: should this be here?
|
||||||
|
defer tracer.Stop()
|
||||||
|
|
||||||
|
// Pprof
|
||||||
|
mux.HandleFunc("/debug/pprof/", pprofhttp.Index)
|
||||||
|
mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline)
|
||||||
|
mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile)
|
||||||
|
mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol)
|
||||||
|
mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace)
|
||||||
|
|
||||||
|
// Func to stop all of the above
|
||||||
|
return func() {
|
||||||
|
profiler.Stop()
|
||||||
|
tracer.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHishtoryVersion(r *http.Request) string {
|
||||||
|
return r.Header.Get("X-Hishtory-Version")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRemoteAddr(r *http.Request) string {
|
||||||
|
addr, ok := r.Header["X-Real-Ip"]
|
||||||
|
if !ok || len(addr) == 0 {
|
||||||
|
return "UnknownIp"
|
||||||
|
}
|
||||||
|
return addr[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRequiredQueryParam(r *http.Request, queryParam string) string {
|
||||||
|
val := r.URL.Query().Get(queryParam)
|
||||||
|
if val == "" {
|
||||||
|
panic(fmt.Sprintf("request to %s is missing required query param=%#v", r.URL, queryParam))
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkGormError(err error) {
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, filename, line, _ := runtime.Caller(1)
|
||||||
|
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err))
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user