Improved update flow

Using the previously added new API endpoint, the update flow can now skip updates if the latest version is already installed. This also improves the output by making it so update can print the version. Also improved the error handling.
This commit is contained in:
David Dworken 2022-04-16 20:50:02 -07:00
parent 735a98a611
commit 158f08f5c6
4 changed files with 49 additions and 18 deletions

View File

@ -461,6 +461,7 @@ func TestUpdate(t *testing.T) {
// Update
RunInteractiveBashCommands(t, `hishtory update`)
// TODO: assert on the output of ^
// Then check the status command again to confirm the update worked
out = RunInteractiveBashCommands(t, `hishtory status`)

View File

@ -40,6 +40,8 @@ var ConfigShContents string
//go:embed test_config.sh
var TestConfigShContents string
var Version string = "Unknown"
func getCwd() (string, error) {
cwd, err := os.Getwd()
if err != nil {
@ -421,17 +423,25 @@ func copyFile(src, dst string) error {
func Update() error {
// Download the binary
var stdout bytes.Buffer
var stderr bytes.Buffer
cmd := exec.Command("bash", "-c", `
curl -L -o /tmp/hishtory-client.intoto.jsonl https://api.hishtory.dev/download/hishtory-linux-amd64.intoto.jsonl;
curl -L -o /tmp/hishtory-client https://api.hishtory.dev/download/hishtory-linux-amd64;
chmod +x /tmp/hishtory-client`)
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
respBody, err := ApiGet("/api/v1/download")
if err != nil {
return fmt.Errorf("failed to download update: %v, stdout=%#v, stderr=%#v", err, stdout.String(), stderr.String())
return fmt.Errorf("failed to download update info: %v", err)
}
var downloadData shared.UpdateInfo
err = json.Unmarshal(respBody, &downloadData)
if err != nil {
return fmt.Errorf("failed to parse update info: %v", err)
}
if downloadData.Version == Version {
fmt.Printf("Latest version (v0.%s) is already installed\n", Version)
}
err = downloadFile("/tmp/hishtory-client", downloadData.LinuxAmd64Url)
if err != nil {
return err
}
err = downloadFile("/tmp/hishtory-client.intoto.jsonl", downloadData.LinuxAmd64AttestationUrl)
if err != nil {
return err
}
// Verify the SLSA attestation
@ -440,7 +450,7 @@ func Update() error {
return fmt.Errorf("failed to verify SLSA provenance of the updated binary, aborting update: %v", err)
}
// Unlink the existing binary
// Unlink the existing binary so we can overwrite it even though it is still running
homedir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user's home directory: %v", err)
@ -449,16 +459,39 @@ func Update() error {
if err != nil {
return fmt.Errorf("failed to unlink %s for update: %v", path.Join(homedir, shared.HISHTORY_PATH, "hishtory"), err)
}
// Install the new one
cmd := exec.Command("chmod", "+x", "/tmp/hishtory-client")
err = cmd.Run()
if err != nil {
return fmt.Errorf("failed to chmod +x the update: %v", err)
}
cmd = exec.Command("/tmp/hishtory-client", "install")
err = cmd.Run()
if err != nil {
return fmt.Errorf("failed to update: %v", err)
}
fmt.Println("Successfully updated hishtory!")
fmt.Printf("Successfully updated hishtory from v0.%s to %s\n", Version, downloadData.Version)
return nil
}
func downloadFile(filename, url string) error {
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to download file at %s to %s: %v", url, filename, err)
}
defer resp.Body.Close()
out, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to save file to %s: %v", filename, err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
return err
}
func getServerHostname() string {
if server := os.Getenv("HISHTORY_SERVER"); server != "" {
return server

View File

@ -14,10 +14,7 @@ import (
"github.com/ddworken/hishtory/shared"
)
var (
GitCommit string = "Unknown"
Version string = "Unknown"
)
var GitCommit string = "Unknown"
func main() {
if len(os.Args) == 1 {
@ -44,7 +41,7 @@ func main() {
case "status":
config, err := lib.GetConfig()
lib.CheckFatalError(err)
fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", Version, config.IsEnabled)
fmt.Printf("Hishtory: v0.%s\nEnabled: %v\n", lib.Version, config.IsEnabled)
fmt.Printf("Secret Key: %s\n", config.UserSecret)
if len(os.Args) == 3 && os.Args[2] == "-v" {
fmt.Printf("User ID: %s\n", data.UserId(config.UserSecret))

View File

@ -1,4 +1,4 @@
#!/usr/bin/env bash
GIT_HASH=$(git rev-parse HEAD)
echo "-X main.GitCommit=$GIT_HASH -X main.Version=`cat VERSION` -w -extldflags \"-static\""
echo "-X main.GitCommit=$GIT_HASH -X client.lib.Version=`cat VERSION` -w -extldflags \"-static\""