diff --git a/client/client_test.go b/client/client_test.go index 3a44073..5edea6d 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -461,6 +461,7 @@ func TestUpdate(t *testing.T) { // Update RunInteractiveBashCommands(t, `hishtory update`) + // TODO: assert on the output of ^ // Then check the status command again to confirm the update worked out = RunInteractiveBashCommands(t, `hishtory status`) diff --git a/client/lib/lib.go b/client/lib/lib.go index dbc2619..a9338ff 100644 --- a/client/lib/lib.go +++ b/client/lib/lib.go @@ -40,6 +40,8 @@ var ConfigShContents string //go:embed test_config.sh var TestConfigShContents string +var Version string = "Unknown" + func getCwd() (string, error) { cwd, err := os.Getwd() if err != nil { @@ -421,17 +423,25 @@ func copyFile(src, dst string) error { func Update() error { // Download the binary - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd := exec.Command("bash", "-c", ` - curl -L -o /tmp/hishtory-client.intoto.jsonl https://api.hishtory.dev/download/hishtory-linux-amd64.intoto.jsonl; - curl -L -o /tmp/hishtory-client https://api.hishtory.dev/download/hishtory-linux-amd64; - chmod +x /tmp/hishtory-client`) - cmd.Stdout = &stdout - cmd.Stderr = &stderr - err := cmd.Run() + respBody, err := ApiGet("/api/v1/download") if err != nil { - return fmt.Errorf("failed to download update: %v, stdout=%#v, stderr=%#v", err, stdout.String(), stderr.String()) + return fmt.Errorf("failed to download update info: %v", err) + } + var downloadData shared.UpdateInfo + err = json.Unmarshal(respBody, &downloadData) + if err != nil { + return fmt.Errorf("failed to parse update info: %v", err) + } + if downloadData.Version == Version { + fmt.Printf("Latest version (v0.%s) is already installed\n", Version) + } + err = downloadFile("/tmp/hishtory-client", downloadData.LinuxAmd64Url) + if err != nil { + return err + } + err = downloadFile("/tmp/hishtory-client.intoto.jsonl", downloadData.LinuxAmd64AttestationUrl) + if err != nil { + return err } // Verify the SLSA attestation @@ -440,7 +450,7 @@ func Update() error { return fmt.Errorf("failed to verify SLSA provenance of the updated binary, aborting update: %v", err) } - // Unlink the existing binary + // Unlink the existing binary so we can overwrite it even though it is still running homedir, err := os.UserHomeDir() if err != nil { return fmt.Errorf("failed to get user's home directory: %v", err) @@ -449,16 +459,39 @@ func Update() error { if err != nil { return fmt.Errorf("failed to unlink %s for update: %v", path.Join(homedir, shared.HISHTORY_PATH, "hishtory"), err) } + // Install the new one + cmd := exec.Command("chmod", "+x", "/tmp/hishtory-client") + err = cmd.Run() + if err != nil { + return fmt.Errorf("failed to chmod +x the update: %v", err) + } cmd = exec.Command("/tmp/hishtory-client", "install") err = cmd.Run() if err != nil { return fmt.Errorf("failed to update: %v", err) } - fmt.Println("Successfully updated hishtory!") + fmt.Printf("Successfully updated hishtory from v0.%s to %s\n", Version, downloadData.Version) return nil } +func downloadFile(filename, url string) error { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to download file at %s to %s: %v", url, filename, err) + } + defer resp.Body.Close() + + out, err := os.Create(filename) + if err != nil { + return fmt.Errorf("failed to save file to %s: %v", filename, err) + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + return err +} + func getServerHostname() string { if server := os.Getenv("HISHTORY_SERVER"); server != "" { return server diff --git a/hishtory.go b/hishtory.go index 0aaab58..88c851f 100644 --- a/hishtory.go +++ b/hishtory.go @@ -14,10 +14,7 @@ import ( "github.com/ddworken/hishtory/shared" ) -var ( - GitCommit string = "Unknown" - Version string = "Unknown" -) +var GitCommit string = "Unknown" func main() { if len(os.Args) == 1 { @@ -44,7 +41,7 @@ func main() { case "status": config, err := lib.GetConfig() lib.CheckFatalError(err) - fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", Version, config.IsEnabled) + fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", lib.Version, config.IsEnabled) fmt.Printf("Secret Key: %s\n", config.UserSecret) if len(os.Args) == 3 && os.Args[2] == "-v" { fmt.Printf("User ID: %s\n", data.UserId(config.UserSecret)) diff --git a/scripts/client-ldflags b/scripts/client-ldflags index f5aa994..350bf30 100755 --- a/scripts/client-ldflags +++ b/scripts/client-ldflags @@ -1,4 +1,4 @@ #!/usr/bin/env bash GIT_HASH=$(git rev-parse HEAD) -echo "-X main.GitCommit=$GIT_HASH -X main.Version=`cat VERSION` -w -extldflags \"-static\"" +echo "-X main.GitCommit=$GIT_HASH -X client.lib.Version=`cat VERSION` -w -extldflags \"-static\""