isolate all server handlers into a single struct, without using global variables

This commit is contained in:
Sergio Moura 2023-09-12 09:26:20 -04:00
parent 50c74e5881
commit 02b1e8287d
4 changed files with 748 additions and 620 deletions

View File

@ -0,0 +1,81 @@
package main
import (
"fmt"
"github.com/DataDog/datadog-go/statsd"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"net/http"
"reflect"
"runtime"
"strings"
"time"
)
type loggedResponseData struct {
size int
}
type loggingResponseWriter struct {
http.ResponseWriter
responseData *loggedResponseData
}
func (r *loggingResponseWriter) Write(b []byte) (int, error) {
size, err := r.ResponseWriter.Write(b)
r.responseData.size += size
return size, err
}
func (r *loggingResponseWriter) WriteHeader(statusCode int) {
r.ResponseWriter.WriteHeader(statusCode)
}
func getFunctionName(temp interface{}) string {
strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".")
return strs[len(strs)-1]
}
func byteCountToString(b int) string {
const unit = 1000
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
}
type Middleware func(http.HandlerFunc) http.Handler
func withLogging(s *statsd.Client) Middleware {
return func(h http.HandlerFunc) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
var responseData loggedResponseData
lrw := loggingResponseWriter{
ResponseWriter: rw,
responseData: &responseData,
}
start := time.Now()
span, ctx := tracer.StartSpanFromContext(
r.Context(),
getFunctionName(h),
tracer.SpanType(ext.SpanTypeSQL),
tracer.ServiceName("hishtory-api"),
)
defer span.Finish()
h.ServeHTTP(&lrw, r.WithContext(ctx))
duration := time.Since(start)
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
if s != nil {
s.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0)
s.Incr("hishtory.request", []string{}, 1.0)
}
})
}
}

View File

@ -4,35 +4,27 @@ import (
"context"
"encoding/json"
"fmt"
"html"
"io"
"log"
"math"
"net/http"
"os"
"reflect"
"runtime"
"strconv"
"strings"
"time"
pprofhttp "net/http/pprof"
"github.com/DataDog/datadog-go/statsd"
"github.com/ddworken/hishtory/internal/database"
"github.com/ddworken/hishtory/shared"
_ "github.com/lib/pq"
"github.com/rodaine/table"
httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/DataDog/dd-trace-go.v1/profiler"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable"
PostgresDb = "postgresql://postgres:%s@postgres:5432/hishtory?sslmode=disable"
StatsdSocket = "unix:///var/run/datadog/dsd.socket"
)
var (
@ -53,195 +45,6 @@ func getHishtoryVersion(r *http.Request) string {
return r.Header.Get("X-Hishtory-Version")
}
func updateUsageData(r *http.Request, userId, deviceId string, numEntriesHandled int, isQuery bool) error {
var usageData []shared.UsageData
usageData, err := GLOBAL_DB.UsageDataFindByUserAndDevice(r.Context(), userId, deviceId)
if err != nil {
return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err)
}
if len(usageData) == 0 {
err := GLOBAL_DB.CreateUsageData(
r.Context(),
&shared.UsageData{
UserId: userId,
DeviceId: deviceId,
LastUsed: time.Now(),
NumEntriesHandled: numEntriesHandled,
Version: getHishtoryVersion(r),
},
)
if err != nil {
return fmt.Errorf("db.CreateUsageData: %w", err)
}
} else {
usage := usageData[0]
if err := GLOBAL_DB.UpdateUsageData(r.Context(), userId, deviceId, time.Now(), getRemoteAddr(r)); err != nil {
return fmt.Errorf("db.UpdateUsageData: %w", err)
}
if numEntriesHandled > 0 {
if err := GLOBAL_DB.UpdateUsageDataForNumEntriesHandled(r.Context(), userId, deviceId, numEntriesHandled); err != nil {
return fmt.Errorf("db.UpdateUsageDataForNumEntriesHandled: %w", err)
}
}
if usage.Version != getHishtoryVersion(r) {
if err := GLOBAL_DB.UpdateUsageDataClientVersion(r.Context(), userId, deviceId, getHishtoryVersion(r)); err != nil {
return fmt.Errorf("db.UpdateUsageDataClientVersion: %w", err)
}
}
}
if isQuery {
if err := GLOBAL_DB.UpdateUsageDataNumberQueries(r.Context(), userId, deviceId); err != nil {
return fmt.Errorf("db.UpdateUsageDataNumberQueries: %w", err)
}
}
return nil
}
func usageStatsHandler(w http.ResponseWriter, r *http.Request) {
usageData, err := GLOBAL_DB.UsageDataStats(r.Context())
if err != nil {
panic(fmt.Errorf("db.UsageDataStats: %w", err))
}
tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs")
tbl.WithWriter(w)
for _, data := range usageData {
versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "")
lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "")
tbl.AddRow(
data.RegistrationDate.Format(shared.DateOnly),
data.NumDevices,
data.NumEntries,
data.NumQueries,
data.LastUsedDate.Format(shared.DateOnly),
lastQueryStr,
versions,
data.IpAddresses,
)
}
tbl.Print()
}
func statsHandler(w http.ResponseWriter, r *http.Request) {
numDevices, err := GLOBAL_DB.CountAllDevices(r.Context())
checkGormError(err, 0)
numEntriesProcessed, err := GLOBAL_DB.UsageDataTotal(r.Context())
checkGormError(err, 0)
numDbEntries, err := GLOBAL_DB.CountHistoryEntries(r.Context())
checkGormError(err, 0)
oneWeek := time.Hour * 24 * 7
weeklyActiveInstalls, err := GLOBAL_DB.CountActiveInstalls(r.Context(), oneWeek)
checkGormError(err, 0)
weeklyQueryUsers, err := GLOBAL_DB.CountQueryUsers(r.Context(), oneWeek)
checkGormError(err, 0)
lastRegistration, err := GLOBAL_DB.DateOfLastRegistration(r.Context())
checkGormError(err, 0)
_, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices)
_, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed)
_, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries)
_, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls)
_, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers)
_, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration)
}
func apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var entries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &entries)
if err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries))
if len(entries) == 0 {
return
}
_ = updateUsageData(r, entries[0].UserId, entries[0].DeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false)
devices, err := GLOBAL_DB.DevicesForUser(r.Context(), entries[0].UserId)
checkGormError(err, 0)
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
}
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
err = GLOBAL_DB.AddHistoryEntriesForAllDevices(r.Context(), devices, entries)
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
}
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
_ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false)
historyEntries, err := GLOBAL_DB.AllHistoryEntriesForUser(r.Context(), userId)
checkGormError(err, 1)
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
panic(err)
}
}
func apiQueryHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
_ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, true)
// Delete any entries that match a pending deletion request
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
for _, request := range deletionRequests {
_, err := GLOBAL_DB.ApplyDeletionRequestsToBackend(r.Context(), request)
checkGormError(err, 0)
}
// Then retrieve
historyEntries, err := GLOBAL_DB.HistoryEntriesForDevice(r.Context(), deviceId, 5)
checkGormError(err, 0)
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
panic(err)
}
// And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids
// blocking the entire response. This does have a potential race condition, but that is fine.
if isProductionEnvironment() {
go func() {
span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount")
err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId)
span.Finish(tracer.WithError(err))
}()
} else {
err := GLOBAL_DB.IncrementEntryReadCountsForDevice(ctx, deviceId)
if err != nil {
panic("failed to increment read counts")
}
}
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Incr("hishtory.query", []string{}, 1.0)
}
}
func getRemoteAddr(r *http.Request) string {
addr, ok := r.Header["X-Real-Ip"]
if !ok || len(addr) == 0 {
@ -250,191 +53,6 @@ func getRemoteAddr(r *http.Request) string {
return addr[0]
}
func apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
if getMaximumNumberOfAllowedUsers() < math.MaxInt {
numDistinctUsers, err := GLOBAL_DB.DistinctUsers(r.Context())
if err != nil {
panic(fmt.Errorf("db.DistinctUsers: %w", err))
}
if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) {
panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers()))
}
}
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
existingDevicesCount, err := GLOBAL_DB.CountDevicesForUser(r.Context(), userId)
checkGormError(err, 0)
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
if err := GLOBAL_DB.CreateDevice(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil {
checkGormError(err, 0)
}
if existingDevicesCount > 0 {
err := GLOBAL_DB.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})
checkGormError(err, 0)
}
_ = updateUsageData(r, userId, deviceId /* numEntriesHandled = */, 0 /* isQuery = */, false)
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Incr("hishtory.register", []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
dumpRequests, err := GLOBAL_DB.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
if err := json.NewEncoder(w).Encode(dumpRequests); err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
}
}
func apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
srcDeviceId := getRequiredQueryParam(r, "source_device_id")
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var entries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &entries)
if err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
// sanity check
for _, entry := range entries {
entry.DeviceId = requestingDeviceId
if entry.UserId != userId {
panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId))
}
}
err = GLOBAL_DB.AddHistoryEntries(r.Context(), entries...)
checkGormError(err, 0)
err = GLOBAL_DB.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId)
checkGormError(err, 0)
_ = updateUsageData(r, userId, srcDeviceId /* numEntriesHandled = */, len(entries) /* isQuery = */, false)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func apiBannerHandler(w http.ResponseWriter, r *http.Request) {
commitHash := getRequiredQueryParam(r, "commit_hash")
deviceId := getRequiredQueryParam(r, "device_id")
forcedBanner := r.URL.Query().Get("forced_banner")
fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner)
if getHishtoryVersion(r) == "v0.160" {
w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory."))
return
}
w.Write([]byte(html.EscapeString(forcedBanner)))
}
func getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount
err := GLOBAL_DB.DeletionRequestInc(r.Context(), userId, deviceId)
checkGormError(err, 0)
// Return all the deletion requests
deletionRequests, err := GLOBAL_DB.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
if err := json.NewEncoder(w).Encode(deletionRequests); err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
}
}
func addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var request shared.DeletionRequest
if err := json.Unmarshal(data, &request); err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
request.ReadCount = 0
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
err = GLOBAL_DB.DeletionRequestCreate(r.Context(), &request)
checkGormError(err, 0)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func healthCheckHandler(w http.ResponseWriter, r *http.Request) {
if isProductionEnvironment() {
encHistoryEntryCount, err := GLOBAL_DB.CountHistoryEntries(r.Context())
checkGormError(err, 0)
if encHistoryEntryCount < 1000 {
panic("Suspiciously few enc history entries!")
}
deviceCount, err := GLOBAL_DB.CountAllDevices(r.Context())
checkGormError(err, 0)
if deviceCount < 100 {
panic("Suspiciously few devices!")
}
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
err = GLOBAL_DB.AddHistoryEntries(r.Context(), &shared.EncHistoryEntry{
EncryptedData: []byte("data"),
Nonce: []byte("nonce"),
DeviceId: "healthcheck_device_id",
UserId: "healthcheck_user_id",
Date: time.Now(),
EncryptedId: "healthcheck_enc_id",
ReadCount: 10000,
})
checkGormError(err, 0)
} else {
err := GLOBAL_DB.Ping()
if err != nil {
panic(fmt.Errorf("failed to ping DB: %w", err))
}
}
w.Write([]byte("OK"))
}
func wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
panic("refusing to wipe the DB for prod")
}
if !isTestEnvironment() {
panic("refusing to wipe the DB non-test environment")
}
err := GLOBAL_DB.Unsafe_DeleteAllHistoryEntries(r.Context())
checkGormError(err, 0)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
stats, err := GLOBAL_DB.Stats()
if err != nil {
panic(err)
}
_, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections)
}
func isTestEnvironment() bool {
return os.Getenv("HISHTORY_TEST") != ""
}
@ -540,16 +158,6 @@ func runBackgroundJobs(ctx context.Context) {
}
}
func triggerCronHandler(w http.ResponseWriter, r *http.Request) {
err := cron(r.Context())
if err != nil {
panic(err)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
type releaseInfo struct {
Name string `json:"name"`
}
@ -674,200 +282,22 @@ func buildUpdateInfo(version string) shared.UpdateInfo {
}
}
func apiDownloadHandler(w http.ResponseWriter, r *http.Request) {
updateInfo := buildUpdateInfo(ReleaseVersion)
resp, err := json.Marshal(updateInfo)
if err != nil {
panic(err)
}
w.Write(resp)
}
func slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
// returns "OK" unless there is a current SLSA bug
v := getHishtoryVersion(r)
if !strings.Contains(v, "v0.") {
w.Write([]byte("OK"))
return
}
vNum, err := strconv.Atoi(strings.Split(v, ".")[1])
if err != nil {
w.Write([]byte("OK"))
return
}
if vNum < 159 {
w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163"))
return
}
w.Write([]byte("OK"))
}
func feedbackHandler(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var feedback shared.Feedback
err = json.Unmarshal(data, &feedback)
if err != nil {
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
}
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
err = GLOBAL_DB.FeedbackCreate(r.Context(), &feedback)
checkGormError(err, 0)
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Incr("hishtory.uninstall", []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
type loggedResponseData struct {
size int
}
type loggingResponseWriter struct {
http.ResponseWriter
responseData *loggedResponseData
}
func (r *loggingResponseWriter) Write(b []byte) (int, error) {
size, err := r.ResponseWriter.Write(b)
r.responseData.size += size
return size, err
}
func (r *loggingResponseWriter) WriteHeader(statusCode int) {
r.ResponseWriter.WriteHeader(statusCode)
}
func getFunctionName(temp interface{}) string {
strs := strings.Split((runtime.FuncForPC(reflect.ValueOf(temp).Pointer()).Name()), ".")
return strs[len(strs)-1]
}
func withLogging(h http.HandlerFunc) http.Handler {
logFn := func(rw http.ResponseWriter, r *http.Request) {
var responseData loggedResponseData
lrw := loggingResponseWriter{
ResponseWriter: rw,
responseData: &responseData,
}
start := time.Now()
span, ctx := tracer.StartSpanFromContext(
r.Context(),
getFunctionName(h),
tracer.SpanType(ext.SpanTypeSQL),
tracer.ServiceName("hishtory-api"),
)
defer span.Finish()
h(&lrw, r.WithContext(ctx))
duration := time.Since(start)
fmt.Printf("%s %s %#v %s %s %s\n", getRemoteAddr(r), r.Method, r.RequestURI, getHishtoryVersion(r), duration.String(), byteCountToString(responseData.size))
if GLOBAL_STATSD != nil {
GLOBAL_STATSD.Distribution("hishtory.request_duration", float64(duration.Microseconds())/1_000, []string{"HANDLER=" + getFunctionName(h)}, 1.0)
GLOBAL_STATSD.Incr("hishtory.request", []string{}, 1.0)
}
}
return http.HandlerFunc(logFn)
}
func byteCountToString(b int) string {
const unit = 1000
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "kMG"[exp])
}
func configureObservability(mux *httptrace.ServeMux) func() {
// Profiler
err := profiler.Start(
profiler.WithService("hishtory-api"),
profiler.WithVersion(ReleaseVersion),
profiler.WithAPIKey(os.Getenv("DD_API_KEY")),
profiler.WithUDS("/var/run/datadog/apm.socket"),
profiler.WithProfileTypes(
profiler.CPUProfile,
profiler.HeapProfile,
),
)
if err != nil {
fmt.Printf("Failed to start DataDog profiler: %v\n", err)
}
// Tracer
tracer.Start(
tracer.WithRuntimeMetrics(),
tracer.WithService("hishtory-api"),
tracer.WithUDS("/var/run/datadog/apm.socket"),
)
defer tracer.Stop()
// Stats
ddStats, err := statsd.New("unix:///var/run/datadog/dsd.socket")
func main() {
s, err := statsd.New(StatsdSocket)
if err != nil {
fmt.Printf("Failed to start DataDog statsd: %v\n", err)
}
GLOBAL_STATSD = ddStats
// Pprof
mux.HandleFunc("/debug/pprof/", pprofhttp.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace)
// Func to stop all of the above
return func() {
profiler.Stop()
tracer.Stop()
// TODO: remove this global once we have a better way to pass it around
GLOBAL_STATSD = s
srv := NewServer(GLOBAL_DB, WithStatsd(s))
if err := srv.Run(context.Background(), ":8080"); err != nil {
panic(err)
}
}
func main() {
mux := httptrace.NewServeMux()
if isProductionEnvironment() {
defer configureObservability(mux)()
go func() {
if err := GLOBAL_DB.DeepClean(context.Background()); err != nil {
panic(err)
}
}()
}
mux.Handle("/api/v1/submit", withLogging(apiSubmitHandler))
mux.Handle("/api/v1/get-dump-requests", withLogging(apiGetPendingDumpRequestsHandler))
mux.Handle("/api/v1/submit-dump", withLogging(apiSubmitDumpHandler))
mux.Handle("/api/v1/query", withLogging(apiQueryHandler))
mux.Handle("/api/v1/bootstrap", withLogging(apiBootstrapHandler))
mux.Handle("/api/v1/register", withLogging(apiRegisterHandler))
mux.Handle("/api/v1/banner", withLogging(apiBannerHandler))
mux.Handle("/api/v1/download", withLogging(apiDownloadHandler))
mux.Handle("/api/v1/trigger-cron", withLogging(triggerCronHandler))
mux.Handle("/api/v1/get-deletion-requests", withLogging(getDeletionRequestsHandler))
mux.Handle("/api/v1/add-deletion-request", withLogging(addDeletionRequestHandler))
mux.Handle("/api/v1/slsa-status", withLogging(slsaStatusHandler))
mux.Handle("/api/v1/feedback", withLogging(feedbackHandler))
mux.Handle("/healthcheck", withLogging(healthCheckHandler))
mux.Handle("/internal/api/v1/usage-stats", withLogging(usageStatsHandler))
mux.Handle("/internal/api/v1/stats", withLogging(statsHandler))
if isTestEnvironment() {
mux.Handle("/api/v1/wipe-db-entries", withLogging(wipeDbEntriesHandler))
mux.Handle("/api/v1/get-num-connections", withLogging(getNumConnectionsHandler))
}
fmt.Println("Listening on localhost:8080")
log.Fatal(http.ListenAndServe(":8080", mux))
}
func checkGormResult(result *gorm.DB) {
checkGormError(result.Error, 1)
}

View File

@ -24,6 +24,7 @@ import (
func TestESubmitThenQuery(t *testing.T) {
// Set up
InitDB()
s := NewServer(GLOBAL_DB)
// Register a few devices
userId := data.UserId("key")
@ -32,11 +33,11 @@ func TestESubmitThenQuery(t *testing.T) {
otherUser := data.UserId("otherkey")
otherDev := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev+"&user_id="+otherUser, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
// Submit a few entries for different devices
entry := testutils.MakeFakeHistoryEntry("ls ~/")
@ -45,12 +46,12 @@ func TestESubmitThenQuery(t *testing.T) {
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(httptest.NewRecorder(), submitReq)
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
// Query for device id 1
w := httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiQueryHandler(w, searchReq)
s.apiQueryHandler(w, searchReq)
res := w.Result()
defer res.Body.Close()
respBody, err := io.ReadAll(res.Body)
@ -79,7 +80,7 @@ func TestESubmitThenQuery(t *testing.T) {
// Same for device id 2
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
apiQueryHandler(w, searchReq)
s.apiQueryHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -107,7 +108,7 @@ func TestESubmitThenQuery(t *testing.T) {
// Bootstrap handler should return 2 entries, one for each device
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?user_id="+data.UserId("key")+"&device_id="+devId1, nil)
apiBootstrapHandler(w, searchReq)
s.apiBootstrapHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -124,6 +125,7 @@ func TestESubmitThenQuery(t *testing.T) {
func TestDumpRequestAndResponse(t *testing.T) {
// Set up
InitDB()
s := NewServer(GLOBAL_DB)
// Register a first device for two different users
userId := data.UserId("dkey")
@ -133,17 +135,17 @@ func TestDumpRequestAndResponse(t *testing.T) {
otherDev1 := uuid.Must(uuid.NewRandom()).String()
otherDev2 := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
// Query for dump requests, there should be one for userId
w := httptest.NewRecorder()
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
res := w.Result()
defer res.Body.Close()
respBody, err := io.ReadAll(res.Body)
@ -163,7 +165,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
// And one for otherUser
w = httptest.NewRecorder()
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -183,7 +185,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
// And none if we query for a user ID that doesn't exit
w = httptest.NewRecorder()
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=foo&device_id=bar", nil))
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -193,7 +195,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
// And none for a missing user ID
w = httptest.NewRecorder()
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id=%20&device_id=%20", nil))
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -211,11 +213,11 @@ func TestDumpRequestAndResponse(t *testing.T) {
reqBody, err := json.Marshal([]shared.EncHistoryEntry{entry1, entry2})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/?user_id="+userId+"&requesting_device_id="+devId2+"&source_device_id="+devId1, bytes.NewReader(reqBody))
apiSubmitDumpHandler(httptest.NewRecorder(), submitReq)
s.apiSubmitDumpHandler(httptest.NewRecorder(), submitReq)
// Check that the dump request is no longer there for userId for either device ID
w = httptest.NewRecorder()
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId1, nil))
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -226,7 +228,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
w = httptest.NewRecorder()
// The other user
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+userId+"&device_id="+devId2, nil))
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -236,7 +238,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
// But it is there for the other user
w = httptest.NewRecorder()
apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
s.apiGetPendingDumpRequestsHandler(w, httptest.NewRequest(http.MethodGet, "/?user_id="+otherUser+"&device_id="+otherDev1, nil))
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -257,7 +259,7 @@ func TestDumpRequestAndResponse(t *testing.T) {
// And finally, query to ensure that the dumped entries are in the DB
w = httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
apiQueryHandler(w, searchReq)
s.apiQueryHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -323,6 +325,7 @@ func TestUpdateReleaseVersion(t *testing.T) {
func TestDeletionRequests(t *testing.T) {
// Set up
InitDB()
s := NewServer(GLOBAL_DB)
// Register two devices for two different users
userId := data.UserId("dkey")
@ -332,13 +335,13 @@ func TestDeletionRequests(t *testing.T) {
otherDev1 := uuid.Must(uuid.NewRandom()).String()
otherDev2 := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId2+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev2+"&user_id="+otherUser, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
// Add an entry for user1
entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
@ -348,7 +351,7 @@ func TestDeletionRequests(t *testing.T) {
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(httptest.NewRecorder(), submitReq)
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
// And another entry for user1
entry2 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
@ -358,7 +361,7 @@ func TestDeletionRequests(t *testing.T) {
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(httptest.NewRecorder(), submitReq)
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
// And an entry for user2 that has the same timestamp as the previous entry
entry3 := testutils.MakeFakeHistoryEntry("ls /foo/bar")
@ -369,12 +372,12 @@ func TestDeletionRequests(t *testing.T) {
reqBody, err = json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(httptest.NewRecorder(), submitReq)
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
// Query for device id 1
w := httptest.NewRecorder()
searchReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiQueryHandler(w, searchReq)
s.apiQueryHandler(w, searchReq)
res := w.Result()
defer res.Body.Close()
respBody, err := io.ReadAll(res.Body)
@ -413,13 +416,13 @@ func TestDeletionRequests(t *testing.T) {
reqBody, err = json.Marshal(delReq)
testutils.Check(t, err)
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
addDeletionRequestHandler(httptest.NewRecorder(), req)
s.addDeletionRequestHandler(httptest.NewRecorder(), req)
// Query again for device id 1 and get a single result
time.Sleep(10 * time.Millisecond)
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiQueryHandler(w, searchReq)
s.apiQueryHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -447,7 +450,7 @@ func TestDeletionRequests(t *testing.T) {
// Query for user 2
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+otherDev1+"&user_id="+otherUser, nil)
apiQueryHandler(w, searchReq)
s.apiQueryHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -475,7 +478,7 @@ func TestDeletionRequests(t *testing.T) {
// Query for deletion requests
w = httptest.NewRecorder()
searchReq = httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
getDeletionRequestsHandler(w, searchReq)
s.getDeletionRequestsHandler(w, searchReq)
res = w.Result()
defer res.Body.Close()
respBody, err = io.ReadAll(res.Body)
@ -504,8 +507,9 @@ func TestDeletionRequests(t *testing.T) {
}
func TestHealthcheck(t *testing.T) {
s := NewServer(GLOBAL_DB)
w := httptest.NewRecorder()
healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil))
if w.Code != 200 {
t.Fatalf("expected 200 resp code for healthCheckHandler")
}
@ -524,6 +528,7 @@ func TestHealthcheck(t *testing.T) {
func TestLimitRegistrations(t *testing.T) {
// Set up
InitDB()
s := NewServer(GLOBAL_DB)
checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries"))
checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices"))
defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")()
@ -531,28 +536,29 @@ func TestLimitRegistrations(t *testing.T) {
// Register three devices across two users
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
// And this next one should fail since it is a new user
defer func() { _ = recover() }()
deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
t.Errorf("expected panic")
}
func TestCleanDatabaseNoErrors(t *testing.T) {
// Init
InitDB()
s := NewServer(GLOBAL_DB)
// Create a user and an entry
userId := data.UserId("dkey")
devId1 := uuid.Must(uuid.NewRandom()).String()
deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+devId1+"&user_id="+userId, nil)
apiRegisterHandler(httptest.NewRecorder(), deviceReq)
s.apiRegisterHandler(httptest.NewRecorder(), deviceReq)
entry1 := testutils.MakeFakeHistoryEntry("ls ~/")
entry1.DeviceId = devId1
encEntry, err := data.EncryptHistoryEntry("dkey", entry1)
@ -560,7 +566,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) {
reqBody, err := json.Marshal([]shared.EncHistoryEntry{encEntry})
testutils.Check(t, err)
submitReq := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(reqBody))
apiSubmitHandler(httptest.NewRecorder(), submitReq)
s.apiSubmitHandler(httptest.NewRecorder(), submitReq)
// Call cleanDatabase and just check that there are no panics
testutils.Check(t, GLOBAL_DB.Clean(context.TODO()))

611
backend/server/srv.go Normal file
View File

@ -0,0 +1,611 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"html"
"io"
"math"
"net/http"
pprofhttp "net/http/pprof"
"os"
"strconv"
"strings"
"time"
"github.com/DataDog/datadog-go/statsd"
"github.com/ddworken/hishtory/internal/database"
"github.com/ddworken/hishtory/shared"
"github.com/rodaine/table"
httptrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/net/http"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/DataDog/dd-trace-go.v1/profiler"
)
type Srv struct {
db *database.DB
statsd *statsd.Client
}
type ServerOption func(*Srv)
func WithStatsd(statsd *statsd.Client) ServerOption {
return func(s *Srv) {
s.statsd = statsd
}
}
func NewServer(db *database.DB, options ...ServerOption) *Srv {
srv := Srv{db: db}
for _, option := range options {
option(&srv)
}
return &srv
}
func (s *Srv) Run(ctx context.Context, addr string) error {
mux := httptrace.NewServeMux()
if isProductionEnvironment() {
defer configureObservability(mux)()
go func() {
if err := s.db.DeepClean(ctx); err != nil {
panic(err)
}
}()
}
loggerMiddleware := withLogging(s.statsd)
mux.Handle("/api/v1/submit", loggerMiddleware(s.apiSubmitHandler))
mux.Handle("/api/v1/get-dump-requests", loggerMiddleware(s.apiGetPendingDumpRequestsHandler))
mux.Handle("/api/v1/submit-dump", loggerMiddleware(s.apiSubmitDumpHandler))
mux.Handle("/api/v1/query", loggerMiddleware(s.apiQueryHandler))
mux.Handle("/api/v1/bootstrap", loggerMiddleware(s.apiBootstrapHandler))
mux.Handle("/api/v1/register", loggerMiddleware(s.apiRegisterHandler))
mux.Handle("/api/v1/banner", loggerMiddleware(s.apiBannerHandler))
mux.Handle("/api/v1/download", loggerMiddleware(s.apiDownloadHandler))
mux.Handle("/api/v1/trigger-cron", loggerMiddleware(s.triggerCronHandler))
mux.Handle("/api/v1/get-deletion-requests", loggerMiddleware(s.getDeletionRequestsHandler))
mux.Handle("/api/v1/add-deletion-request", loggerMiddleware(s.addDeletionRequestHandler))
mux.Handle("/api/v1/slsa-status", loggerMiddleware(s.slsaStatusHandler))
mux.Handle("/api/v1/feedback", loggerMiddleware(s.feedbackHandler))
mux.Handle("/healthcheck", loggerMiddleware(s.healthCheckHandler))
mux.Handle("/internal/api/v1/usage-stats", loggerMiddleware(s.usageStatsHandler))
mux.Handle("/internal/api/v1/stats", loggerMiddleware(s.statsHandler))
if isTestEnvironment() {
mux.Handle("/api/v1/wipe-db-entries", loggerMiddleware(s.wipeDbEntriesHandler))
mux.Handle("/api/v1/get-num-connections", loggerMiddleware(s.getNumConnectionsHandler))
}
httpServer := &http.Server{
Addr: addr,
Handler: mux,
}
fmt.Printf("Listening on %s\n", addr)
if err := httpServer.ListenAndServe(); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("http.ListenAndServe: %w", err)
}
}
return nil
}
func (s *Srv) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var entries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &entries)
if err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("apiSubmitHandler: received request containg %d EncHistoryEntry\n", len(entries))
if len(entries) == 0 {
return
}
// TODO: add these to the context in a middleware
version := getHishtoryVersion(r)
remoteIPAddr := getRemoteAddr(r)
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, entries[0].UserId, entries[0].DeviceId, len(entries), false); err != nil {
fmt.Printf("updateUsageData: %v\n", err)
}
devices, err := s.db.DevicesForUser(r.Context(), entries[0].UserId)
checkGormError(err, 0)
if len(devices) == 0 {
panic(fmt.Errorf("found no devices associated with user_id=%s, can't save history entry", entries[0].UserId))
}
fmt.Printf("apiSubmitHandler: Found %d devices\n", len(devices))
err = s.db.DeviceEntriesCreateChunk(r.Context(), devices, entries, 1000)
if err != nil {
panic(fmt.Errorf("failed to execute transaction to add entries to DB: %w", err))
}
if s.statsd != nil {
s.statsd.Count("hishtory.submit", int64(len(devices)), []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (s *Srv) apiBootstrapHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
// TODO: add these to the context in a middleware
version := getHishtoryVersion(r)
remoteIPAddr := getRemoteAddr(r)
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil {
fmt.Printf("updateUsageData: %v\n", err)
}
historyEntries, err := s.db.EncHistoryEntriesForUser(r.Context(), userId)
checkGormError(err, 1)
fmt.Printf("apiBootstrapHandler: Found %d entries\n", len(historyEntries))
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
panic(err)
}
}
func (s *Srv) apiQueryHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
// TODO: add these to the context in a middleware
version := getHishtoryVersion(r)
remoteIPAddr := getRemoteAddr(r)
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, true); err != nil {
fmt.Printf("updateUsageData: %v\n", err)
}
// Delete any entries that match a pending deletion request
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
for _, request := range deletionRequests {
_, err := s.db.ApplyDeletionRequestsToBackend(r.Context(), request)
checkGormError(err, 0)
}
// Then retrieve
historyEntries, err := s.db.EncHistoryEntriesForDevice(r.Context(), deviceId, 5)
checkGormError(err, 0)
fmt.Printf("apiQueryHandler: Found %d entries for %s\n", len(historyEntries), r.URL)
if err := json.NewEncoder(w).Encode(historyEntries); err != nil {
panic(err)
}
// And finally, kick off a background goroutine that will increment the read count. Doing it in the background avoids
// blocking the entire response. This does have a potential race condition, but that is fine.
if isProductionEnvironment() {
go func() {
span, ctx := tracer.StartSpanFromContext(ctx, "apiQueryHandler.incrementReadCount")
err := s.db.DeviceIncrementReadCounts(ctx, deviceId)
span.Finish(tracer.WithError(err))
}()
} else {
err := s.db.DeviceIncrementReadCounts(ctx, deviceId)
if err != nil {
panic("failed to increment read counts")
}
}
if s.statsd != nil {
s.statsd.Incr("hishtory.query", []string{}, 1.0)
}
}
func (s *Srv) apiSubmitDumpHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
srcDeviceId := getRequiredQueryParam(r, "source_device_id")
requestingDeviceId := getRequiredQueryParam(r, "requesting_device_id")
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var entries []*shared.EncHistoryEntry
err = json.Unmarshal(data, &entries)
if err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
fmt.Printf("apiSubmitDumpHandler: received request containg %d EncHistoryEntry\n", len(entries))
// sanity check
for _, entry := range entries {
entry.DeviceId = requestingDeviceId
if entry.UserId != userId {
panic(fmt.Errorf("batch contains an entry with UserId=%#v, when the query param contained the user_id=%#v", entry.UserId, userId))
}
}
err = s.db.EncHistoryCreateMulti(r.Context(), entries...)
checkGormError(err, 0)
err = s.db.DumpRequestDeleteForUserAndDevice(r.Context(), userId, requestingDeviceId)
checkGormError(err, 0)
// TODO: add these to the context in a middleware
version := getHishtoryVersion(r)
remoteIPAddr := getRemoteAddr(r)
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, srcDeviceId, len(entries), false); err != nil {
fmt.Printf("updateUsageData: %v\n", err)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (s *Srv) apiBannerHandler(w http.ResponseWriter, r *http.Request) {
commitHash := getRequiredQueryParam(r, "commit_hash")
deviceId := getRequiredQueryParam(r, "device_id")
forcedBanner := r.URL.Query().Get("forced_banner")
fmt.Printf("apiBannerHandler: commit_hash=%#v, device_id=%#v, forced_banner=%#v\n", commitHash, deviceId, forcedBanner)
if getHishtoryVersion(r) == "v0.160" {
w.Write([]byte("Warning: hiSHtory v0.160 has a bug that slows down your shell! Please run `hishtory update` to upgrade hiSHtory."))
return
}
w.Write([]byte(html.EscapeString(forcedBanner)))
}
func (s *Srv) apiGetPendingDumpRequestsHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
var dumpRequests []*shared.DumpRequest
// Filter out ones requested by the hishtory instance that sent this request
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
if err := json.NewEncoder(w).Encode(dumpRequests); err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
}
}
func (s *Srv) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) {
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
// Increment the ReadCount
err := s.db.DeletionRequestInc(r.Context(), userId, deviceId)
checkGormError(err, 0)
// Return all the deletion requests
deletionRequests, err := s.db.DeletionRequestsForUserAndDevice(r.Context(), userId, deviceId)
checkGormError(err, 0)
if err := json.NewEncoder(w).Encode(deletionRequests); err != nil {
panic(fmt.Errorf("failed to JSON marshall the dump requests: %w", err))
}
}
func (s *Srv) addDeletionRequestHandler(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var request shared.DeletionRequest
if err := json.Unmarshal(data, &request); err != nil {
panic(fmt.Sprintf("body=%#v, err=%v", data, err))
}
request.ReadCount = 0
fmt.Printf("addDeletionRequestHandler: received request containg %d messages to be deleted\n", len(request.Messages.Ids))
err = s.db.DeletionRequestCreate(r.Context(), &request)
checkGormError(err, 0)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (_ *Srv) apiDownloadHandler(w http.ResponseWriter, r *http.Request) {
updateInfo := buildUpdateInfo(ReleaseVersion)
resp, err := json.Marshal(updateInfo)
if err != nil {
panic(err)
}
w.Write(resp)
}
func (s *Srv) apiRegisterHandler(w http.ResponseWriter, r *http.Request) {
if getMaximumNumberOfAllowedUsers() < math.MaxInt {
numDistinctUsers, err := s.db.DistinctUsers(r.Context())
if err != nil {
panic(fmt.Errorf("db.DistinctUsers: %w", err))
}
if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) {
panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers()))
}
}
userId := getRequiredQueryParam(r, "user_id")
deviceId := getRequiredQueryParam(r, "device_id")
existingDevicesCount, err := s.db.DevicesCountForUser(r.Context(), userId)
checkGormError(err, 0)
fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)
if err := s.db.DeviceCreate(r.Context(), &shared.Device{UserId: userId, DeviceId: deviceId, RegistrationIp: getRemoteAddr(r), RegistrationDate: time.Now()}); err != nil {
checkGormError(err, 0)
}
if existingDevicesCount > 0 {
err := s.db.DumpRequestCreate(r.Context(), &shared.DumpRequest{UserId: userId, RequestingDeviceId: deviceId, RequestTime: time.Now()})
checkGormError(err, 0)
}
// TODO: add these to the context in a middleware
version := getHishtoryVersion(r)
remoteIPAddr := getRemoteAddr(r)
if err := s.updateUsageData(r.Context(), version, remoteIPAddr, userId, deviceId, 0, false); err != nil {
fmt.Printf("updateUsageData: %v\n", err)
}
if s.statsd != nil {
s.statsd.Incr("hishtory.register", []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (s *Srv) triggerCronHandler(w http.ResponseWriter, r *http.Request) {
err := cron(r.Context())
if err != nil {
panic(err)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (s *Srv) slsaStatusHandler(w http.ResponseWriter, r *http.Request) {
// returns "OK" unless there is a current SLSA bug
v := getHishtoryVersion(r)
if !strings.Contains(v, "v0.") {
w.Write([]byte("OK"))
return
}
vNum, err := strconv.Atoi(strings.Split(v, ".")[1])
if err != nil {
w.Write([]byte("OK"))
return
}
if vNum < 159 {
w.Write([]byte("Sigstore deployed a broken change. See https://github.com/slsa-framework/slsa-github-generator/issues/1163"))
return
}
w.Write([]byte("OK"))
}
func (s *Srv) feedbackHandler(w http.ResponseWriter, r *http.Request) {
data, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
var feedback shared.Feedback
err = json.Unmarshal(data, &feedback)
if err != nil {
panic(fmt.Sprintf("feedbackHandler: body=%#v, err=%v", data, err))
}
fmt.Printf("feedbackHandler: received request containg feedback %#v\n", feedback)
err = s.db.FeedbackCreate(r.Context(), &feedback)
checkGormError(err, 0)
if s.statsd != nil {
s.statsd.Incr("hishtory.uninstall", []string{}, 1.0)
}
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (s *Srv) healthCheckHandler(w http.ResponseWriter, r *http.Request) {
if isProductionEnvironment() {
// Check that we have a reasonable looking set of devices/entries in the DB
//rows, err := s.db.Raw("SELECT true FROM enc_history_entries LIMIT 1 OFFSET 1000").Rows()
//if err != nil {
// panic(fmt.Sprintf("failed to count entries in DB: %v", err))
//}
//defer rows.Close()
//if !rows.Next() {
// panic("Suspiciously few enc history entries!")
//}
encHistoryEntryCount, err := s.db.EncHistoryEntryCount(r.Context())
checkGormError(err, 0)
if encHistoryEntryCount < 1000 {
panic("Suspiciously few enc history entries!")
}
deviceCount, err := s.db.DevicesCount(r.Context())
checkGormError(err, 0)
if deviceCount < 100 {
panic("Suspiciously few devices!")
}
// Check that we can write to the DB. This entry will get written and then eventually cleaned by the cron.
err = s.db.EncHistoryCreate(r.Context(), &shared.EncHistoryEntry{
EncryptedData: []byte("data"),
Nonce: []byte("nonce"),
DeviceId: "healthcheck_device_id",
UserId: "healthcheck_user_id",
Date: time.Now(),
EncryptedId: "healthcheck_enc_id",
ReadCount: 10000,
})
checkGormError(err, 0)
} else {
err := s.db.Ping()
if err != nil {
panic(fmt.Errorf("failed to ping DB: %w", err))
}
}
w.Write([]byte("OK"))
}
func (s *Srv) usageStatsHandler(w http.ResponseWriter, r *http.Request) {
usageData, err := s.db.UsageDataStats(r.Context())
if err != nil {
panic(fmt.Errorf("db.UsageDataStats: %w", err))
}
tbl := table.New("Registration Date", "Num Devices", "Num Entries", "Num Queries", "Last Active", "Last Query", "Versions", "IPs")
tbl.WithWriter(w)
for _, data := range usageData {
versions := strings.ReplaceAll(strings.ReplaceAll(data.Versions, "Unknown", ""), ", ", "")
lastQueryStr := strings.ReplaceAll(data.LastQueried.Format(shared.DateOnly), "1970-01-01", "")
tbl.AddRow(
data.RegistrationDate.Format(shared.DateOnly),
data.NumDevices,
data.NumEntries,
data.NumQueries,
data.LastUsedDate.Format(shared.DateOnly),
lastQueryStr,
versions,
data.IpAddresses,
)
}
tbl.Print()
}
func (s *Srv) statsHandler(w http.ResponseWriter, r *http.Request) {
numDevices, err := s.db.DevicesCount(r.Context())
checkGormError(err, 0)
numEntriesProcessed, err := s.db.UsageDataTotal(r.Context())
checkGormError(err, 0)
numDbEntries, err := s.db.EncHistoryEntryCount(r.Context())
checkGormError(err, 0)
oneWeek := time.Hour * 24 * 7
weeklyActiveInstalls, err := s.db.WeeklyActiveInstalls(r.Context(), oneWeek)
checkGormError(err, 0)
weeklyQueryUsers, err := s.db.WeeklyQueryUsers(r.Context(), oneWeek)
checkGormError(err, 0)
lastRegistration, err := s.db.LastRegistration(r.Context())
checkGormError(err, 0)
_, _ = fmt.Fprintf(w, "Num devices: %d\n", numDevices)
_, _ = fmt.Fprintf(w, "Num history entries processed: %d\n", numEntriesProcessed)
_, _ = fmt.Fprintf(w, "Num DB entries: %d\n", numDbEntries)
_, _ = fmt.Fprintf(w, "Weekly active installs: %d\n", weeklyActiveInstalls)
_, _ = fmt.Fprintf(w, "Weekly active queries: %d\n", weeklyQueryUsers)
_, _ = fmt.Fprintf(w, "Last registration: %s\n", lastRegistration)
}
func (s *Srv) wipeDbEntriesHandler(w http.ResponseWriter, r *http.Request) {
if r.Host == "api.hishtory.dev" || isProductionEnvironment() {
panic("refusing to wipe the DB for prod")
}
if !isTestEnvironment() {
panic("refusing to wipe the DB non-test environment")
}
err := s.db.EncHistoryClear(r.Context())
checkGormError(err, 0)
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
func (s *Srv) getNumConnectionsHandler(w http.ResponseWriter, r *http.Request) {
stats, err := s.db.Stats()
if err != nil {
panic(err)
}
_, _ = fmt.Fprintf(w, "%#v", stats.OpenConnections)
}
func (s *Srv) updateUsageData(ctx context.Context, version string, remoteAddr string, userId, deviceId string, numEntriesHandled int, isQuery bool) error {
var usageData []shared.UsageData
usageData, err := s.db.UsageDataFindByUserAndDevice(ctx, userId, deviceId)
if err != nil {
return fmt.Errorf("db.UsageDataFindByUserAndDevice: %w", err)
}
if len(usageData) == 0 {
err := s.db.UsageDataCreate(
ctx,
&shared.UsageData{
UserId: userId,
DeviceId: deviceId,
LastUsed: time.Now(),
NumEntriesHandled: numEntriesHandled,
Version: version,
},
)
if err != nil {
return fmt.Errorf("db.UsageDataCreate: %w", err)
}
} else {
usage := usageData[0]
if err := s.db.UsageDataUpdate(ctx, userId, deviceId, time.Now(), remoteAddr); err != nil {
return fmt.Errorf("db.UsageDataUpdate: %w", err)
}
if numEntriesHandled > 0 {
if err := s.db.UsageDataUpdateNumEntriesHandled(ctx, userId, deviceId, numEntriesHandled); err != nil {
return fmt.Errorf("db.UsageDataUpdateNumEntriesHandled: %w", err)
}
}
if usage.Version != version {
if err := s.db.UsageDataUpdateVersion(ctx, userId, deviceId, version); err != nil {
return fmt.Errorf("db.UsageDataUpdateVersion: %w", err)
}
}
}
if isQuery {
if err := s.db.UsageDataUpdateNumQueries(ctx, userId, deviceId); err != nil {
return fmt.Errorf("db.UsageDataUpdateNumQueries: %w", err)
}
}
return nil
}
func configureObservability(mux *httptrace.ServeMux) func() {
// Profiler
err := profiler.Start(
profiler.WithService("hishtory-api"),
profiler.WithVersion(ReleaseVersion),
profiler.WithAPIKey(os.Getenv("DD_API_KEY")),
profiler.WithUDS("/var/run/datadog/apm.socket"),
profiler.WithProfileTypes(
profiler.CPUProfile,
profiler.HeapProfile,
),
)
if err != nil {
fmt.Printf("Failed to start DataDog profiler: %v\n", err)
}
// Tracer
tracer.Start(
tracer.WithRuntimeMetrics(),
tracer.WithService("hishtory-api"),
tracer.WithUDS("/var/run/datadog/apm.socket"),
)
// TODO: should this be here?
defer tracer.Stop()
// Pprof
mux.HandleFunc("/debug/pprof/", pprofhttp.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprofhttp.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprofhttp.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprofhttp.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprofhttp.Trace)
// Func to stop all of the above
return func() {
profiler.Stop()
tracer.Stop()
}
}