Always include user and device ID in API request headers, so that they're available in all server-side handlers

This commit is contained in:
David Dworken 2023-10-14 10:52:35 -07:00
parent 54c3429bca
commit fca2b1441f
No known key found for this signature in database
12 changed files with 87 additions and 62 deletions

View File

@ -540,7 +540,7 @@ func installFromHead(t *testing.T, tester shellTester) (string, string) {
func installFromPrev(t *testing.T, tester shellTester) (string, string) {
defer testutils.BackupAndRestoreEnv("HISHTORY_FORCE_CLIENT_VERSION")()
dd, err := lib.GetDownloadData()
dd, err := lib.GetDownloadData(makeTestOnlyContextWithFakeConfig())
require.NoError(t, err)
pv, err := shared.ParseVersionString(dd.Version)
require.NoError(t, err)
@ -553,7 +553,7 @@ func installFromPrev(t *testing.T, tester shellTester) (string, string) {
}
func updateToRelease(t *testing.T, tester shellTester) string {
dd, err := lib.GetDownloadData()
dd, err := lib.GetDownloadData(makeTestOnlyContextWithFakeConfig())
require.NoError(t, err)
// Update
@ -962,9 +962,10 @@ func testRequestAndReceiveDbDump(t *testing.T, tester shellTester) {
secretKey := installHishtory(t, tester, "")
// Confirm there are no pending dump requests
config := hctx.GetConf(hctx.MakeContext())
ctx := hctx.MakeContext()
config := hctx.GetConf(ctx)
deviceId1 := config.DeviceId
respBytes, err := lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey) + "&device_id=" + deviceId1)
respBytes, err := lib.ApiGet(ctx, "/api/v1/get-dump-requests?user_id="+data.UserId(secretKey)+"&device_id="+deviceId1)
resp := strings.TrimSpace(string(respBytes))
require.NoError(t, err, "failed to get pending dump requests")
require.Equalf(t, "[]", resp, "there are pending dump requests! user_id=%#v, resp=%#v", data.UserId(secretKey), resp)
@ -988,14 +989,14 @@ echo other`)
restoreFirstInstallation := testutils.BackupAndRestoreWithId(t, "-install1")
// Wipe the DB to simulate entries getting deleted because they've already been read and expired
_, err = lib.ApiGet("/api/v1/wipe-db-entries")
_, err = lib.ApiGet(ctx, "/api/v1/wipe-db-entries")
require.NoError(t, err, "failed to wipe the remote DB")
// Install a new one (with the same secret key but a diff device id)
installHishtory(t, tester, secretKey)
// Confirm there is now a pending dump requests that the first device should respond to
respBytes, err = lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey) + "&device_id=" + deviceId1)
respBytes, err = lib.ApiGet(ctx, "/api/v1/get-dump-requests?user_id="+data.UserId(secretKey)+"&device_id="+deviceId1)
resp = strings.TrimSpace(string(respBytes))
require.NoError(t, err, "failed to get pending dump requests")
require.NotEqualf(t, "[]", resp, "There are no pending dump requests! user_id=%#v, resp=%#v", data.UserId(secretKey), string(resp))
@ -1025,7 +1026,7 @@ echo other`)
}
// Confirm there are no pending dump requests for the first device
respBytes, err = lib.ApiGet("/api/v1/get-dump-requests?user_id=" + data.UserId(secretKey) + "&device_id=" + deviceId1)
respBytes, err = lib.ApiGet(ctx, "/api/v1/get-dump-requests?user_id="+data.UserId(secretKey)+"&device_id="+deviceId1)
resp = strings.TrimSpace(string(respBytes))
require.NoError(t, err, "failed to get pending dump requests")
require.Equalf(t, "[]", resp, "There are pending dump requests! user_id=%#v, resp=%#v", data.UserId(secretKey), string(resp))
@ -1090,7 +1091,7 @@ func testInstallViaPythonScriptChild(t *testing.T, tester shellTester) {
userSecret := matches[1]
// Test the status subcommand
downloadData, err := lib.GetDownloadData()
downloadData, err := lib.GetDownloadData(makeTestOnlyContextWithFakeConfig())
require.NoError(t, err)
out = tester.RunInteractiveShell(t, `hishtory status`)
expectedOut := fmt.Sprintf("hiSHtory: %s\nEnabled: true\nSecret Key: %s\nCommit Hash: ", downloadData.Version, userSecret)

View File

@ -115,7 +115,7 @@ var uninstallCmd = &cobra.Command{
}
reqBody, err := json.Marshal(feedback)
lib.CheckFatalError(err)
_, _ = lib.ApiPost("/api/v1/feedback", "application/json", reqBody)
_, _ = lib.ApiPost(ctx, "/api/v1/feedback", "application/json", reqBody)
lib.CheckFatalError(uninstall(ctx))
},
}

View File

@ -65,7 +65,7 @@ func export(ctx context.Context, query string) {
db := hctx.GetDb(ctx)
err := lib.RetrieveAdditionalEntriesFromRemote(ctx)
if err != nil {
if lib.IsOfflineError(err) {
if lib.IsOfflineError(ctx, err) {
fmt.Println("Warning: hishtory is offline so this may be missing recent results from your other machines!")
} else {
lib.CheckFatalError(err)
@ -82,7 +82,7 @@ func query(ctx context.Context, query string) {
db := hctx.GetDb(ctx)
err := lib.RetrieveAdditionalEntriesFromRemote(ctx)
if err != nil {
if lib.IsOfflineError(err) {
if lib.IsOfflineError(ctx, err) {
fmt.Println("Warning: hishtory is offline so this may be missing recent results from your other machines!")
} else {
lib.CheckFatalError(err)
@ -97,7 +97,7 @@ func query(ctx context.Context, query string) {
func displayBannerIfSet(ctx context.Context) error {
respBody, err := lib.GetBanner(ctx)
if lib.IsOfflineError(err) {
if lib.IsOfflineError(ctx, err) {
return nil
}
if err != nil {

View File

@ -90,7 +90,7 @@ func deleteOnRemoteInstances(ctx context.Context, historyEntries []*data.History
shared.MessageIdentifier{DeviceId: entry.DeviceId, EndTime: entry.EndTime, EntryId: entry.EntryId},
)
}
return lib.SendDeletionRequest(deletionRequest)
return lib.SendDeletionRequest(ctx, deletionRequest)
}
func init() {

View File

@ -58,8 +58,8 @@ func maybeSubmitPendingDeletionRequests(ctx context.Context) error {
// Upload the missing deletion requests
for _, dr := range config.PendingDeletionRequests {
err := lib.SendDeletionRequest(dr)
if lib.IsOfflineError(err) {
err := lib.SendDeletionRequest(ctx, dr)
if lib.IsOfflineError(ctx, err) {
// We're still offline, so nothing to do
return nil
}
@ -94,7 +94,7 @@ func maybeUploadSkippedHistoryEntries(ctx context.Context) error {
if err != nil {
return err
}
_, err = lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
_, err = lib.ApiPost(ctx, "/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
if err != nil {
// Failed to upload the history entry, so we must still be offline. So just return nil and we'll try again later.
return nil
@ -110,9 +110,9 @@ func maybeUploadSkippedHistoryEntries(ctx context.Context) error {
return nil
}
func handlePotentialUploadFailure(err error, config *hctx.ClientConfig, entryTimestamp time.Time) {
func handlePotentialUploadFailure(ctx context.Context, err error, config *hctx.ClientConfig, entryTimestamp time.Time) {
if err != nil {
if lib.IsOfflineError(err) {
if lib.IsOfflineError(ctx, err) {
hctx.GetLogger().Infof("Failed to remotely persist hishtory entry because we failed to connect to the remote server! This is likely because the device is offline, but also could be because the remote server is having reliability issues. Original error: %v", err)
if !config.HaveMissedUploads {
config.HaveMissedUploads = true
@ -165,8 +165,8 @@ func presaveHistoryEntry(ctx context.Context) {
if !config.IsOffline {
jsonValue, err := lib.EncryptAndMarshal(config, []*data.HistoryEntry{entry})
lib.CheckFatalError(err)
_, err = lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
handlePotentialUploadFailure(err, config, entry.StartTime)
_, err = lib.ApiPost(ctx, "/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
handlePotentialUploadFailure(ctx, err, config, entry.StartTime)
}
}
@ -197,8 +197,8 @@ func saveHistoryEntry(ctx context.Context) {
if !config.IsOffline {
jsonValue, err := lib.EncryptAndMarshal(config, []*data.HistoryEntry{entry})
lib.CheckFatalError(err)
w, err := lib.ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
handlePotentialUploadFailure(err, config, entry.StartTime)
w, err := lib.ApiPost(ctx, "/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
handlePotentialUploadFailure(ctx, err, config, entry.StartTime)
if err == nil {
submitResponse := shared.SubmitResponse{}
err := json.Unmarshal(w, &submitResponse)
@ -264,8 +264,8 @@ func deletePresavedEntries(ctx context.Context, entry *data.HistoryEntry) error
// Note that we aren't specifying an EndTime here since pre-saved entries don't have an EndTime
shared.MessageIdentifier{DeviceId: presavedEntry.DeviceId, EntryId: presavedEntry.EntryId},
)
err = lib.SendDeletionRequest(deletionRequest)
if lib.IsOfflineError(err) {
err = lib.SendDeletionRequest(ctx, deletionRequest)
if lib.IsOfflineError(ctx, err) {
// Cache the deletion request to send once the client comes back online
config.PendingDeletionRequests = append(config.PendingDeletionRequests, deletionRequest)
return hctx.SetConfig(config)
@ -298,7 +298,7 @@ func handleDumpRequests(ctx context.Context, dumpRequests []*shared.DumpRequest)
for _, dumpRequest := range dumpRequests {
if !config.IsOffline {
// TODO: Test whether this fails if the data is extremely large? It may need to be chunked
_, err := lib.ApiPost("/api/v1/submit-dump?user_id="+dumpRequest.UserId+"&requesting_device_id="+dumpRequest.RequestingDeviceId+"&source_device_id="+config.DeviceId, "application/json", reqBody)
_, err := lib.ApiPost(ctx, "/api/v1/submit-dump?user_id="+dumpRequest.UserId+"&requesting_device_id="+dumpRequest.RequestingDeviceId+"&source_device_id="+config.DeviceId, "application/json", reqBody)
lib.CheckFatalError(err)
}
}

View File

@ -30,7 +30,7 @@ var updateCmd = &cobra.Command{
func update(ctx context.Context) error {
// Download the binary
downloadData, err := lib.GetDownloadData()
downloadData, err := lib.GetDownloadData(ctx)
if err != nil {
return err
}

View File

@ -111,9 +111,9 @@ func OpenLocalSqliteDb() (*gorm.DB, error) {
type hishtoryContextKey string
const (
configCtxKey hishtoryContextKey = "config"
dbCtxKey hishtoryContextKey = "db"
homedirCtxKey hishtoryContextKey = "homedir"
ConfigCtxKey hishtoryContextKey = "config"
DbCtxKey hishtoryContextKey = "db"
HomedirCtxKey hishtoryContextKey = "homedir"
)
func MakeContext() context.Context {
@ -122,22 +122,22 @@ func MakeContext() context.Context {
if err != nil {
panic(fmt.Errorf("failed to retrieve config: %w", err))
}
ctx = context.WithValue(ctx, configCtxKey, &config)
ctx = context.WithValue(ctx, ConfigCtxKey, &config)
db, err := OpenLocalSqliteDb()
if err != nil {
panic(fmt.Errorf("failed to open local DB: %w", err))
}
ctx = context.WithValue(ctx, dbCtxKey, db)
ctx = context.WithValue(ctx, DbCtxKey, db)
homedir, err := os.UserHomeDir()
if err != nil {
panic(fmt.Errorf("failed to get homedir: %w", err))
}
ctx = context.WithValue(ctx, homedirCtxKey, homedir)
ctx = context.WithValue(ctx, HomedirCtxKey, homedir)
return ctx
}
func GetConf(ctx context.Context) *ClientConfig {
v := ctx.Value(configCtxKey)
v := ctx.Value(ConfigCtxKey)
if v != nil {
return (v.(*ClientConfig))
}
@ -145,7 +145,7 @@ func GetConf(ctx context.Context) *ClientConfig {
}
func GetDb(ctx context.Context) *gorm.DB {
v := ctx.Value(dbCtxKey)
v := ctx.Value(DbCtxKey)
if v != nil {
return v.(*gorm.DB)
}
@ -153,7 +153,7 @@ func GetDb(ctx context.Context) *gorm.DB {
}
func GetHome(ctx context.Context) string {
v := ctx.Value(homedirCtxKey)
v := ctx.Value(HomedirCtxKey)
if v != nil {
return v.(string)
}

View File

@ -53,6 +53,7 @@ var GitCommit string = "Unknown"
// Funnily enough, 256KB actually wasn't enough. See https://github.com/ddworken/hishtory/issues/93
var maxSupportedLineLengthForImport = 512_000
// TODO: move this function to install.go
func Setup(userSecret string, isOffline bool) error {
if userSecret == "" {
userSecret = uuid.Must(uuid.NewRandom()).String()
@ -82,12 +83,13 @@ func Setup(userSecret string, isOffline bool) error {
if config.IsOffline {
return nil
}
_, err = ApiGet("/api/v1/register?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
ctx := hctx.MakeContext()
_, err = ApiGet(ctx, "/api/v1/register?user_id="+data.UserId(userSecret)+"&device_id="+config.DeviceId)
if err != nil {
return fmt.Errorf("failed to register device with backend: %w", err)
}
respBody, err := ApiGet("/api/v1/bootstrap?user_id=" + data.UserId(userSecret) + "&device_id=" + config.DeviceId)
respBody, err := ApiGet(ctx, "/api/v1/bootstrap?user_id="+data.UserId(userSecret)+"&device_id="+config.DeviceId)
if err != nil {
return fmt.Errorf("failed to bootstrap device from the backend: %w", err)
}
@ -548,8 +550,8 @@ func getServerHostname() string {
return "https://api.hishtory.dev"
}
func GetDownloadData() (shared.UpdateInfo, error) {
respBody, err := ApiGet("/api/v1/download")
func GetDownloadData(ctx context.Context) (shared.UpdateInfo, error) {
respBody, err := ApiGet(ctx, "/api/v1/download")
if err != nil {
return shared.UpdateInfo{}, fmt.Errorf("failed to download update info: %w", err)
}
@ -565,7 +567,7 @@ func httpClient() *http.Client {
return &http.Client{}
}
func ApiGet(path string) ([]byte, error) {
func ApiGet(ctx context.Context, path string) ([]byte, error) {
if os.Getenv("HISHTORY_SIMULATE_NETWORK_ERROR") != "" {
return nil, fmt.Errorf("simulated network error: dial tcp: lookup api.hishtory.dev")
}
@ -575,6 +577,8 @@ func ApiGet(path string) ([]byte, error) {
return nil, fmt.Errorf("failed to create GET: %w", err)
}
req.Header.Set("X-Hishtory-Version", "v0."+Version)
req.Header.Set("X-Hishtory-Device-Id", hctx.GetConf(ctx).DeviceId)
req.Header.Set("X-Hishtory-User-Id", data.UserId(hctx.GetConf(ctx).UserSecret))
resp, err := httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("failed to GET %s%s: %w", getServerHostname(), path, err)
@ -592,17 +596,19 @@ func ApiGet(path string) ([]byte, error) {
return respBody, nil
}
func ApiPost(path, contentType string, data []byte) ([]byte, error) {
func ApiPost(ctx context.Context, path, contentType string, reqBody []byte) ([]byte, error) {
if os.Getenv("HISHTORY_SIMULATE_NETWORK_ERROR") != "" {
return nil, fmt.Errorf("simulated network error: dial tcp: lookup api.hishtory.dev")
}
start := time.Now()
req, err := http.NewRequest("POST", getServerHostname()+path, bytes.NewBuffer(data))
req, err := http.NewRequest("POST", getServerHostname()+path, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("failed to create POST: %w", err)
}
req.Header.Set("Content-Type", contentType)
req.Header.Set("X-Hishtory-Version", "v0."+Version)
req.Header.Set("X-Hishtory-Device-Id", hctx.GetConf(ctx).DeviceId)
req.Header.Set("X-Hishtory-User-Id", data.UserId(hctx.GetConf(ctx).UserSecret))
resp, err := httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("failed to POST %s: %w", getServerHostname()+path, err)
@ -620,7 +626,7 @@ func ApiPost(path, contentType string, data []byte) ([]byte, error) {
return respBody, nil
}
func IsOfflineError(err error) bool {
func IsOfflineError(ctx context.Context, err error) bool {
if err == nil {
return false
}
@ -637,7 +643,7 @@ func IsOfflineError(err error) bool {
strings.Contains(err.Error(), "connect: connection refused") {
return true
}
if !isHishtoryServerUp() {
if !isHishtoryServerUp(ctx) {
// If the backend server is down, then treat all errors as offline errors
return true
}
@ -645,8 +651,8 @@ func IsOfflineError(err error) bool {
return false
}
func isHishtoryServerUp() bool {
_, err := ApiGet("/api/v1/ping")
func isHishtoryServerUp(ctx context.Context) bool {
_, err := ApiGet(ctx, "/api/v1/ping")
return err == nil
}
@ -773,7 +779,7 @@ func Reupload(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to reupload due to failed encryption: %w", err)
}
_, err = ApiPost("/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
_, err = ApiPost(ctx, "/api/v1/submit?source_device_id="+config.DeviceId, "application/json", jsonValue)
if err != nil {
return fmt.Errorf("failed to reupload due to failed POST: %w", err)
}
@ -790,8 +796,8 @@ func RetrieveAdditionalEntriesFromRemote(ctx context.Context) error {
if config.IsOffline {
return nil
}
respBody, err := ApiGet("/api/v1/query?device_id=" + config.DeviceId + "&user_id=" + data.UserId(config.UserSecret))
if IsOfflineError(err) {
respBody, err := ApiGet(ctx, "/api/v1/query?device_id="+config.DeviceId+"&user_id="+data.UserId(config.UserSecret))
if IsOfflineError(ctx, err) {
return nil
}
if err != nil {
@ -817,8 +823,8 @@ func ProcessDeletionRequests(ctx context.Context) error {
if config.IsOffline {
return nil
}
resp, err := ApiGet("/api/v1/get-deletion-requests?user_id=" + data.UserId(config.UserSecret) + "&device_id=" + config.DeviceId)
if IsOfflineError(err) {
resp, err := ApiGet(ctx, "/api/v1/get-deletion-requests?user_id="+data.UserId(config.UserSecret)+"&device_id="+config.DeviceId)
if IsOfflineError(ctx, err) {
return nil
}
if err != nil {
@ -856,7 +862,7 @@ func GetBanner(ctx context.Context) ([]byte, error) {
return []byte{}, nil
}
url := "/api/v1/banner?commit_hash=" + GitCommit + "&user_id=" + data.UserId(config.UserSecret) + "&device_id=" + config.DeviceId + "&version=" + Version + "&forced_banner=" + os.Getenv("FORCED_BANNER")
return ApiGet(url)
return ApiGet(ctx, url)
}
func parseTimeGenerously(input string) (time.Time, error) {
@ -1110,12 +1116,12 @@ func unescape(query string) string {
return string(newQuery)
}
func SendDeletionRequest(deletionRequest shared.DeletionRequest) error {
func SendDeletionRequest(ctx context.Context, deletionRequest shared.DeletionRequest) error {
data, err := json.Marshal(deletionRequest)
if err != nil {
return err
}
_, err = ApiPost("/api/v1/add-deletion-request", "application/json", data)
_, err = ApiPost(ctx, "/api/v1/add-deletion-request", "application/json", data)
if err != nil {
return fmt.Errorf("failed to send deletion request to backend service, this may cause commands to not get deleted on other instances of hishtory: %w", err)
}

View File

@ -362,13 +362,14 @@ func TestAugmentedIsOfflineError(t *testing.T) {
defer testutils.BackupAndRestore(t)()
defer testutils.RunTestServer()()
defer testutils.BackupAndRestoreEnv("HISHTORY_SIMULATE_NETWORK_ERROR")()
ctx := hctx.MakeContext()
// By default, when the hishtory server is up, then IsOfflineError checks the error msg
require.True(t, isHishtoryServerUp())
require.False(t, IsOfflineError(fmt.Errorf("unchecked error type")))
require.True(t, isHishtoryServerUp(ctx))
require.False(t, IsOfflineError(ctx, fmt.Errorf("unchecked error type")))
// When the hishtory server is down, then all error messages are treated as being due to offline errors
os.Setenv("HISHTORY_SIMULATE_NETWORK_ERROR", "1")
require.False(t, isHishtoryServerUp())
require.True(t, IsOfflineError(fmt.Errorf("unchecked error type")))
require.False(t, isHishtoryServerUp(ctx))
require.True(t, IsOfflineError(ctx, fmt.Errorf("unchecked error type")))
}

View File

@ -46,7 +46,7 @@ func VerifyBinary(ctx context.Context, binaryPath, attestationPath, versionTag s
if os.Getenv("HISHTORY_DISABLE_SLSA_ATTESTATION") == "true" {
return nil
}
resp, err := ApiGet("/api/v1/slsa-status?newVersion=" + versionTag)
resp, err := ApiGet(ctx, "/api/v1/slsa-status?newVersion="+versionTag)
if err != nil {
return nil
}

View File

@ -2,10 +2,12 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"regexp"
"runtime"
@ -351,7 +353,7 @@ func captureTerminalOutputComplex(t testing.TB, captureConfig TmuxCaptureConfig)
}
func assertNoLeakedConnections(t testing.TB) {
resp, err := lib.ApiGet("/api/v1/get-num-connections")
resp, err := lib.ApiGet(makeTestOnlyContextWithFakeConfig(), "/api/v1/get-num-connections")
require.NoError(t, err)
numConnections, err := strconv.Atoi(string(resp))
require.NoError(t, err)
@ -368,6 +370,21 @@ func getPidofCommand() string {
return "pidof"
}
func makeTestOnlyContextWithFakeConfig() context.Context {
fakeConfig := hctx.ClientConfig{
UserSecret: "FAKE_TEST_DEVICE",
DeviceId: "FAKE_TEST_DEVICE",
}
ctx := context.Background()
ctx = context.WithValue(ctx, hctx.ConfigCtxKey, &fakeConfig)
// Note: We don't create a DB here
homedir, err := os.UserHomeDir()
if err != nil {
panic(fmt.Errorf("failed to get homedir: %w", err))
}
return context.WithValue(ctx, hctx.HomedirCtxKey, homedir)
}
type deviceSet struct {
deviceMap *map[device]deviceOp
currentDevice *device

View File

@ -675,7 +675,7 @@ func deleteHistoryEntry(ctx context.Context, entry data.HistoryEntry) error {
dr.Messages.Ids = append(dr.Messages.Ids,
shared.MessageIdentifier{DeviceId: entry.DeviceId, EndTime: entry.EndTime, EntryId: entry.EntryId},
)
return lib.SendDeletionRequest(dr)
return lib.SendDeletionRequest(ctx, dr)
}
func TuiQuery(ctx context.Context, initialQuery string) error {
@ -712,7 +712,7 @@ func TuiQuery(ctx context.Context, initialQuery string) error {
go func() {
banner, err := lib.GetBanner(ctx)
if err != nil {
if lib.IsOfflineError(err) {
if lib.IsOfflineError(ctx, err) {
p.Send(offlineMsg{})
} else {
p.Send(err)