Add more thorough tests for updates

This commit is contained in:
David Dworken 2023-10-09 21:41:30 -07:00
parent 82a5e2fced
commit f0dbcd6e3b
No known key found for this signature in database
7 changed files with 204 additions and 111 deletions

View File

@ -52,8 +52,8 @@ func (s *Server) apiSubmitHandler(w http.ResponseWriter, r *http.Request) {
deviceId := getOptionalQueryParam(r, "source_device_id", s.isTestEnvironment)
if deviceId != "" {
hv, err := parseVersionString(version)
if err != nil || hv.greaterThan(parsedVersion{0, 221}) {
hv, err := shared.ParseVersionString(version)
if err != nil || hv.GreaterThan(shared.ParsedVersion{0, 221}) {
// Note that if we fail to parse the version string, we do return dump and deletion requests. This is necessary
// since tests run with v0.Unknown which obviously fails to parse.
dumpRequests, err := s.db.DumpRequestForUserAndDevice(r.Context(), userId, deviceId)

View File

@ -604,44 +604,3 @@ func deserializeSubmitResponse(t *testing.T, w *httptest.ResponseRecorder) share
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &submitResponse))
return submitResponse
}
func TestParseVersionString(t *testing.T) {
p, err := parseVersionString("v0.200")
require.NoError(t, err)
require.Equal(t, parsedVersion{majorVersion: 0, minorVersion: 200}, p)
p, err = parseVersionString("v1.200")
require.NoError(t, err)
require.Equal(t, parsedVersion{majorVersion: 1, minorVersion: 200}, p)
p, err = parseVersionString("v1.0")
require.NoError(t, err)
require.Equal(t, parsedVersion{majorVersion: 1, minorVersion: 0}, p)
p, err = parseVersionString("v0.216")
require.NoError(t, err)
require.Equal(t, parsedVersion{majorVersion: 0, minorVersion: 216}, p)
p, err = parseVersionString("v123.456")
require.NoError(t, err)
require.Equal(t, parsedVersion{majorVersion: 123, minorVersion: 456}, p)
}
func TestVersionLessThan(t *testing.T) {
require.False(t, parsedVersion{0, 200}.lessThan(parsedVersion{0, 200}))
require.False(t, parsedVersion{1, 200}.lessThan(parsedVersion{1, 200}))
require.False(t, parsedVersion{0, 201}.lessThan(parsedVersion{0, 200}))
require.False(t, parsedVersion{1, 0}.lessThan(parsedVersion{0, 200}))
require.True(t, parsedVersion{0, 199}.lessThan(parsedVersion{0, 200}))
require.True(t, parsedVersion{0, 200}.lessThan(parsedVersion{0, 205}))
require.True(t, parsedVersion{1, 200}.lessThan(parsedVersion{1, 205}))
require.True(t, parsedVersion{0, 200}.lessThan(parsedVersion{1, 1}))
}
func TestVersionGreaterThan(t *testing.T) {
require.False(t, parsedVersion{0, 200}.greaterThan(parsedVersion{0, 200}))
require.False(t, parsedVersion{1, 200}.greaterThan(parsedVersion{1, 200}))
require.True(t, parsedVersion{0, 201}.greaterThan(parsedVersion{0, 200}))
require.True(t, parsedVersion{1, 0}.greaterThan(parsedVersion{0, 200}))
require.True(t, parsedVersion{1, 1}.greaterThan(parsedVersion{1, 0}))
require.False(t, parsedVersion{0, 199}.greaterThan(parsedVersion{0, 200}))
require.False(t, parsedVersion{0, 200}.greaterThan(parsedVersion{0, 205}))
require.False(t, parsedVersion{1, 200}.greaterThan(parsedVersion{1, 205}))
require.False(t, parsedVersion{0, 200}.greaterThan(parsedVersion{1, 1}))
}

View File

@ -6,7 +6,6 @@ import (
"net/http"
pprofhttp "net/http/pprof"
"os"
"regexp"
"runtime"
"strconv"
@ -99,42 +98,3 @@ func checkGormError(err error) {
_, filename, line, _ := runtime.Caller(1)
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err))
}
type parsedVersion struct {
majorVersion int
minorVersion int
}
func (pv parsedVersion) greaterThan(other parsedVersion) bool {
if pv.majorVersion == other.majorVersion && pv.minorVersion == other.minorVersion {
return false
}
return !pv.lessThan(other)
}
func (pv parsedVersion) lessThan(other parsedVersion) bool {
if pv.majorVersion != other.majorVersion {
return pv.majorVersion < other.majorVersion
}
return pv.minorVersion < other.minorVersion
}
func parseVersionString(versionString string) (parsedVersion, error) {
re := regexp.MustCompile(`v(\d+)[.](\d+)`)
matches := re.FindAllStringSubmatch(versionString, -1)
if len(matches) != 1 {
return parsedVersion{}, fmt.Errorf("failed to parse version=%#v (matches=%#v)", versionString, matches)
}
if len(matches[0]) != 3 {
return parsedVersion{}, fmt.Errorf("failed to parse version=%#v (matches[0]=%#v)", versionString, matches[0])
}
majorVersion, err := strconv.Atoi(matches[0][1])
if err != nil {
return parsedVersion{}, fmt.Errorf("failed to parse major version %#v", matches[0][1])
}
minorVersion, err := strconv.Atoi(matches[0][2])
if err != nil {
return parsedVersion{}, fmt.Errorf("failed to parse minor version %#v", matches[0][2])
}
return parsedVersion{majorVersion, minorVersion}, nil
}

View File

@ -21,6 +21,7 @@ import (
"github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/client/hctx"
"github.com/ddworken/hishtory/client/lib"
"github.com/ddworken/hishtory/shared"
"github.com/ddworken/hishtory/shared/testutils"
"github.com/stretchr/testify/require"
)
@ -77,7 +78,9 @@ func TestParam(t *testing.T) {
t.Run("testRepeatedCommandAndQuery/"+tester.ShellName(), func(t *testing.T) { testRepeatedCommandAndQuery(t, tester) })
t.Run("testRepeatedEnableDisable/"+tester.ShellName(), func(t *testing.T) { testRepeatedEnableDisable(t, tester) })
t.Run("testExcludeHiddenCommand/"+tester.ShellName(), func(t *testing.T) { testExcludeHiddenCommand(t, tester) })
t.Run("testUpdate/"+tester.ShellName(), func(t *testing.T) { testUpdate(t, tester) })
t.Run("testUpdate/head->release/"+tester.ShellName(), func(t *testing.T) { testUpdateFromHeadToRelease(t, tester) })
t.Run("testUpdate/prev->release/"+tester.ShellName(), func(t *testing.T) { testUpdateFromPrevToRelease(t, tester) })
t.Run("testUpdate/prev->current/"+tester.ShellName(), func(t *testing.T) { testUpdateFromPrevToCurrent(t, tester) })
t.Run("testAdvancedQuery/"+tester.ShellName(), func(t *testing.T) { testAdvancedQuery(t, tester) })
t.Run("testIntegration/"+tester.ShellName(), func(t *testing.T) { testIntegration(t, tester, Online) })
t.Run("testIntegration/offline/"+tester.ShellName(), func(t *testing.T) { testIntegration(t, tester, Offline) })
@ -530,7 +533,64 @@ hishtory disable`)
}
}
func testUpdate(t *testing.T, tester shellTester) {
func installFromHead(t *testing.T, tester shellTester) (string, string) {
return installHishtory(t, tester, ""), "v0.Unknown"
}
func installFromPrev(t *testing.T, tester shellTester) (string, string) {
defer testutils.BackupAndRestoreEnv("HISHTORY_FORCE_CLIENT_VERSION")()
dd, err := lib.GetDownloadData()
require.NoError(t, err)
pv, err := shared.ParseVersionString(dd.Version)
require.NoError(t, err)
previousVersion := pv.Decrement()
os.Setenv("HISHTORY_FORCE_CLIENT_VERSION", previousVersion.String())
userSecret := installHishtory(t, tester, "")
out := tester.RunInteractiveShell(t, ` hishtory update`)
require.Regexp(t, regexp.MustCompile(`Successfully updated hishtory from v0[.]Unknown to `+previousVersion.String()), out)
return userSecret, previousVersion.String()
}
func updateToRelease(t *testing.T, tester shellTester) string {
dd, err := lib.GetDownloadData()
require.NoError(t, err)
// Update
out := tester.RunInteractiveShell(t, ` hishtory update`)
require.Regexp(t, regexp.MustCompile(`Successfully updated hishtory from v0[.][a-zA-Z0-9]+ to `+dd.Version), out)
require.NotContains(t, out, "skipping SLSA validation")
// Update again and assert that it skipped the update
out = tester.RunInteractiveShell(t, ` hishtory update`)
if strings.Count(out, "\n") != 1 || !strings.Contains(out, "is already installed") {
t.Fatalf("repeated hishtory update didn't skip the update, out=%#v", out)
}
return dd.Version
}
func updateToHead(t *testing.T, tester shellTester) string {
out := tester.RunInteractiveShell(t, ` /tmp/client install`)
require.Equal(t, out, "")
return "v0.Unknown"
}
func testUpdateFromHeadToRelease(t *testing.T, tester shellTester) {
testGenericUpdate(t, tester, installFromHead, updateToRelease)
}
func testUpdateFromPrevToRelease(t *testing.T, tester shellTester) {
testGenericUpdate(t, tester, installFromPrev, updateToRelease)
}
func testUpdateFromPrevToCurrent(t *testing.T, tester shellTester) {
testGenericUpdate(t, tester, installFromPrev, updateToHead)
}
// TODO: Can we duplicate testUpdateFromPrevToCurrent to also run with the prod server?
func testGenericUpdate(t *testing.T, tester shellTester, installInitialVersion func(*testing.T, shellTester) (string, string), installUpdatedVersion func(*testing.T, shellTester) string) {
defer testutils.BackupAndRestoreEnv("HISHTORY_FORCE_CLIENT_VERSION")()
if !testutils.IsOnline() {
t.Skip("skipping because we're currently offline")
}
@ -542,45 +602,36 @@ func testUpdate(t *testing.T, tester shellTester) {
}
// Set up
defer testutils.BackupAndRestore(t)()
userSecret := installHishtory(t, tester, "")
userSecret, initialVersion := installInitialVersion(t, tester)
// Record a command before the update
tester.RunInteractiveShell(t, "echo hello")
// Check the status command
out := tester.RunInteractiveShell(t, `hishtory status`)
if out != fmt.Sprintf("hiSHtory: v0.Unknown\nEnabled: true\nSecret Key: %s\nCommit Hash: Unknown\n", userSecret) {
t.Fatalf("status command has unexpected output: %#v", out)
require.Contains(t, out, fmt.Sprintf("hiSHtory: %s\nEnabled: true\nSecret Key: %s\nCommit Hash: ", initialVersion, userSecret))
if initialVersion == "v0.Unknown" {
require.Contains(t, out, "Commit Hash: Unknown")
} else {
require.NotContains(t, out, "Commit Hash: Unknown")
}
// Update
out = tester.RunInteractiveShell(t, `hishtory update`)
isExpected, err := regexp.MatchString(`Successfully updated hishtory from v0[.]Unknown to v0.\d+`, out)
require.NoError(t, err, "regex failure")
if !isExpected {
t.Fatalf("hishtory update returned unexpected out=%#v", out)
}
require.NotContains(t, out, "skipping SLSA validation")
// Update again and assert that it skipped the update
out = tester.RunInteractiveShell(t, `hishtory update`)
if strings.Count(out, "\n") != 1 || !strings.Contains(out, "is already installed") {
t.Fatalf("repeated hishtory update didn't skip the update, out=%#v", out)
}
updatedVersion := installUpdatedVersion(t, tester)
// Then check the status command again to confirm the update worked
out = tester.RunInteractiveShell(t, `hishtory status`)
require.Contains(t, out, fmt.Sprintf("\nEnabled: true\nSecret Key: %s\nCommit Hash: ", userSecret))
require.NotContains(t, out, "\nCommit Hash: Unknown\n")
if updatedVersion != "v0.Unknown" {
require.NotContains(t, out, "\nCommit Hash: Unknown\n")
}
// Check that the history was preserved after the update
out = tester.RunInteractiveShell(t, "hishtory export -pipefail | grep -v '/tmp/client install'")
expectedOutput := "echo hello\nhishtory status\nhishtory update\nhishtory update\nhishtory status\n"
expectedOutput := "echo hello\nhishtory status\nhishtory status\n"
if diff := cmp.Diff(expectedOutput, out); diff != "" {
t.Fatalf("hishtory export mismatch (-expected +got):\n%s\nout=%#v", diff, out)
}
// TODO: write a test that updates from v.prev to latest rather than v.Unknown to latest
}
func testRepeatedCommandThenQuery(t *testing.T, tester shellTester) {

View File

@ -10,6 +10,7 @@ import (
"os/exec"
"path"
"runtime"
"strings"
"syscall"
"github.com/ddworken/hishtory/client/data"
@ -47,7 +48,7 @@ func update(ctx context.Context) error {
if runtime.GOOS == "darwin" {
slsaError = verifyBinaryMac(ctx, getTmpClientPath(), downloadData)
} else {
slsaError = lib.VerifyBinary(ctx, getTmpClientPath(), getTmpClientPath()+".intoto.jsonl", downloadData.Version)
slsaError = lib.VerifyBinary(ctx, getTmpClientPath(), getTmpClientPath()+".intoto.jsonl", getPossiblyOverriddenVersion(downloadData))
}
if slsaError != nil {
err = lib.HandleSlsaFailure(slsaError)
@ -83,7 +84,7 @@ func update(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to install update (stderr=%#v), is %s in a noexec directory? (if so, set the TMPDIR environment variable): %w", stderr.String(), getTmpClientPath(), err)
}
fmt.Printf("Successfully updated hishtory from v0.%s to %s\n", lib.Version, downloadData.Version)
fmt.Printf("Successfully updated hishtory from v0.%s to %s\n", lib.Version, getPossiblyOverriddenVersion(downloadData))
return nil
}
@ -100,13 +101,19 @@ func verifyBinaryMac(ctx context.Context, binaryPath string, downloadData shared
// go compiler.
unsignedBinaryPath := binaryPath + "-unsigned"
var err error = nil
unsignedUrl := ""
if runtime.GOOS == "darwin" && runtime.GOARCH == "amd64" {
err = downloadFile(unsignedBinaryPath, downloadData.DarwinAmd64UnsignedUrl)
unsignedUrl = downloadData.DarwinAmd64UnsignedUrl
} else if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
err = downloadFile(unsignedBinaryPath, downloadData.DarwinArm64UnsignedUrl)
unsignedUrl = downloadData.DarwinArm64UnsignedUrl
} else {
err = fmt.Errorf("verifyBinaryMac() called for the unhandled branch GOOS=%s, GOARCH=%s", runtime.GOOS, runtime.GOARCH)
return fmt.Errorf("verifyBinaryMac() called for the unhandled branch GOOS=%s, GOARCH=%s", runtime.GOOS, runtime.GOARCH)
}
if forcedVersion := os.Getenv("HISHTORY_FORCE_CLIENT_VERSION"); forcedVersion != "" {
unsignedUrl = strings.ReplaceAll(unsignedUrl, downloadData.Version, forcedVersion)
}
err = downloadFile(unsignedBinaryPath, unsignedUrl)
if err != nil {
return err
}
@ -129,7 +136,7 @@ func verifyBinaryMac(ctx context.Context, binaryPath string, downloadData shared
}
// Step 4: Use SLSA to verify the unsigned binary
return lib.VerifyBinary(ctx, unsignedBinaryPath, getTmpClientPath()+".intoto.jsonl", downloadData.Version)
return lib.VerifyBinary(ctx, unsignedBinaryPath, getTmpClientPath()+".intoto.jsonl", getPossiblyOverriddenVersion(downloadData))
}
func assertIdenticalBinaries(bin1Path, bin2Path string) error {
@ -199,6 +206,10 @@ func downloadFiles(updateInfo shared.UpdateInfo) error {
} else {
return fmt.Errorf("no update info found for GOOS=%s, GOARCH=%s", runtime.GOOS, runtime.GOARCH)
}
if forcedVersion := os.Getenv("HISHTORY_FORCE_CLIENT_VERSION"); forcedVersion != "" {
clientUrl = strings.ReplaceAll(clientUrl, updateInfo.Version, forcedVersion)
clientProvenanceUrl = strings.ReplaceAll(clientProvenanceUrl, updateInfo.Version, forcedVersion)
}
err := downloadFile(getTmpClientPath(), clientUrl)
if err != nil {
return err
@ -210,6 +221,13 @@ func downloadFiles(updateInfo shared.UpdateInfo) error {
return nil
}
func getPossiblyOverriddenVersion(updateInfo shared.UpdateInfo) string {
if forcedVersion := os.Getenv("HISHTORY_FORCE_CLIENT_VERSION"); forcedVersion != "" {
return forcedVersion
}
return updateInfo.Version
}
func getTmpClientPath() string {
tmpDir := "/tmp/"
if os.Getenv("TMPDIR") != "" {

57
shared/version.go Normal file
View File

@ -0,0 +1,57 @@
package shared
import (
"fmt"
"regexp"
"strconv"
)
type ParsedVersion struct {
MajorVersion int
MinorVersion int
}
func (pv ParsedVersion) GreaterThan(other ParsedVersion) bool {
if pv.MajorVersion == other.MajorVersion && pv.MinorVersion == other.MinorVersion {
return false
}
return !pv.LessThan(other)
}
func (pv ParsedVersion) LessThan(other ParsedVersion) bool {
if pv.MajorVersion != other.MajorVersion {
return pv.MajorVersion < other.MajorVersion
}
return pv.MinorVersion < other.MinorVersion
}
func (pv ParsedVersion) Decrement() ParsedVersion {
if pv.MinorVersion > 1 {
return ParsedVersion{pv.MajorVersion, pv.MinorVersion - 1}
}
panic("cannot decrement() when MinorVersion == 0")
}
func (pv ParsedVersion) String() string {
return fmt.Sprintf("v%d.%d", pv.MajorVersion, pv.MinorVersion)
}
func ParseVersionString(versionString string) (ParsedVersion, error) {
re := regexp.MustCompile(`v(\d+)[.](\d+)`)
matches := re.FindAllStringSubmatch(versionString, -1)
if len(matches) != 1 {
return ParsedVersion{}, fmt.Errorf("failed to parse version=%#v (matches=%#v)", versionString, matches)
}
if len(matches[0]) != 3 {
return ParsedVersion{}, fmt.Errorf("failed to parse version=%#v (matches[0]=%#v)", versionString, matches[0])
}
MajorVersion, err := strconv.Atoi(matches[0][1])
if err != nil {
return ParsedVersion{}, fmt.Errorf("failed to parse major version %#v", matches[0][1])
}
MinorVersion, err := strconv.Atoi(matches[0][2])
if err != nil {
return ParsedVersion{}, fmt.Errorf("failed to parse minor version %#v", matches[0][2])
}
return ParsedVersion{MajorVersion, MinorVersion}, nil
}

48
shared/version_test.go Normal file
View File

@ -0,0 +1,48 @@
package shared
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestParseVersionString(t *testing.T) {
p, err := ParseVersionString("v0.200")
require.NoError(t, err)
require.Equal(t, ParsedVersion{MajorVersion: 0, MinorVersion: 200}, p)
p, err = ParseVersionString("v1.200")
require.NoError(t, err)
require.Equal(t, ParsedVersion{MajorVersion: 1, MinorVersion: 200}, p)
p, err = ParseVersionString("v1.0")
require.NoError(t, err)
require.Equal(t, ParsedVersion{MajorVersion: 1, MinorVersion: 0}, p)
p, err = ParseVersionString("v0.216")
require.NoError(t, err)
require.Equal(t, ParsedVersion{MajorVersion: 0, MinorVersion: 216}, p)
p, err = ParseVersionString("v123.456")
require.NoError(t, err)
require.Equal(t, ParsedVersion{MajorVersion: 123, MinorVersion: 456}, p)
}
func TestVersionLessThan(t *testing.T) {
require.False(t, ParsedVersion{0, 200}.LessThan(ParsedVersion{0, 200}))
require.False(t, ParsedVersion{1, 200}.LessThan(ParsedVersion{1, 200}))
require.False(t, ParsedVersion{0, 201}.LessThan(ParsedVersion{0, 200}))
require.False(t, ParsedVersion{1, 0}.LessThan(ParsedVersion{0, 200}))
require.True(t, ParsedVersion{0, 199}.LessThan(ParsedVersion{0, 200}))
require.True(t, ParsedVersion{0, 200}.LessThan(ParsedVersion{0, 205}))
require.True(t, ParsedVersion{1, 200}.LessThan(ParsedVersion{1, 205}))
require.True(t, ParsedVersion{0, 200}.LessThan(ParsedVersion{1, 1}))
}
func TestVersionGreaterThan(t *testing.T) {
require.False(t, ParsedVersion{0, 200}.GreaterThan(ParsedVersion{0, 200}))
require.False(t, ParsedVersion{1, 200}.GreaterThan(ParsedVersion{1, 200}))
require.True(t, ParsedVersion{0, 201}.GreaterThan(ParsedVersion{0, 200}))
require.True(t, ParsedVersion{1, 0}.GreaterThan(ParsedVersion{0, 200}))
require.True(t, ParsedVersion{1, 1}.GreaterThan(ParsedVersion{1, 0}))
require.False(t, ParsedVersion{0, 199}.GreaterThan(ParsedVersion{0, 200}))
require.False(t, ParsedVersion{0, 200}.GreaterThan(ParsedVersion{0, 205}))
require.False(t, ParsedVersion{1, 200}.GreaterThan(ParsedVersion{1, 205}))
require.False(t, ParsedVersion{0, 200}.GreaterThan(ParsedVersion{1, 1}))
}