From 0d30011a33622c6bac13e8bdd2cd0bb0ab4b7ca3 Mon Sep 17 00:00:00 2001 From: Sergio Moura Date: Tue, 12 Sep 2023 10:09:38 -0400 Subject: [PATCH] break down release versions and fix server tests --- backend/server/Dockerfile | 2 +- backend/server/server.go | 130 ++------------------ internal/release/release.go | 127 +++++++++++++++++++ internal/release/release_test.go | 33 +++++ {backend => internal}/server/server_test.go | 97 +++++++-------- internal/server/srv.go | 5 + shared/testutils/testutils.go | 2 +- 7 files changed, 225 insertions(+), 171 deletions(-) create mode 100644 internal/release/release.go create mode 100644 internal/release/release_test.go rename {backend => internal}/server/server_test.go (93%) diff --git a/backend/server/Dockerfile b/backend/server/Dockerfile index 55496c4..909cc61 100644 --- a/backend/server/Dockerfile +++ b/backend/server/Dockerfile @@ -6,7 +6,7 @@ RUN go mod download COPY . ./ ARG GOARCH RUN apk add --update --no-cache --virtual .build-deps build-base && \ - GOARCH=${GOARCH} go build -o /server -ldflags "-X main.ReleaseVersion=v0.`cat VERSION`" backend/server/server.go && \ + GOARCH=${GOARCH} go build -o /server -ldflags "-X release.Version=v0.`cat VERSION`" backend/server/server.go && \ apk del .build-deps FROM alpine:3.17 diff --git a/backend/server/server.go b/backend/server/server.go index bc65d50..9a4e847 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -2,21 +2,16 @@ package main import ( "context" - "encoding/json" "fmt" + "github.com/ddworken/hishtory/internal/release" "github.com/ddworken/hishtory/internal/server" - "io" "log" - "net/http" "os" "runtime" - "strconv" - "strings" "time" "github.com/DataDog/datadog-go/statsd" "github.com/ddworken/hishtory/internal/database" - "github.com/ddworken/hishtory/shared" _ "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -28,9 +23,8 @@ const ( ) var ( - GLOBAL_DB *database.DB - GLOBAL_STATSD *statsd.Client - ReleaseVersion string = "UNKNOWN" + GLOBAL_DB *database.DB + GLOBAL_STATSD *statsd.Client ) func isTestEnvironment() bool { @@ -102,15 +96,14 @@ func OpenDB() (*database.DB, error) { } func init() { - if ReleaseVersion == "UNKNOWN" && !isTestEnvironment() { + if release.Version == "UNKNOWN" && !isTestEnvironment() { panic("server.go was built without a ReleaseVersion!") } InitDB() - go runBackgroundJobs(context.Background()) } func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { - if err := updateReleaseVersion(); err != nil { + if err := release.UpdateReleaseVersion(); err != nil { return fmt.Errorf("updateReleaseVersion: %w", err) } @@ -125,7 +118,7 @@ func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { return nil } -func runBackgroundJobs(ctx context.Context) { +func runBackgroundJobs(ctx context.Context, srv *server.Server) { time.Sleep(5 * time.Second) for { err := cron(ctx, GLOBAL_DB, GLOBAL_STATSD) @@ -135,80 +128,12 @@ func runBackgroundJobs(ctx context.Context) { // cron no longer panics, panicking here. panic(err) } + srv.UpdateReleaseVersion(release.Version, release.BuildUpdateInfo(release.Version)) time.Sleep(10 * time.Minute) } } -type releaseInfo struct { - Name string `json:"name"` -} - -func updateReleaseVersion() error { - resp, err := http.Get("https://api.github.com/repos/ddworken/hishtory/releases/latest") - if err != nil { - return fmt.Errorf("failed to get latest release version: %w", err) - } - defer resp.Body.Close() - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read github API response body: %w", err) - } - if resp.StatusCode == 403 && strings.Contains(string(respBody), "API rate limit exceeded for ") { - return nil - } - if resp.StatusCode != 200 { - return fmt.Errorf("failed to call github API, status_code=%d, body=%#v", resp.StatusCode, string(respBody)) - } - var info releaseInfo - err = json.Unmarshal(respBody, &info) - if err != nil { - return fmt.Errorf("failed to parse github API response: %w", err) - } - latestVersionTag := info.Name - ReleaseVersion = decrementVersionIfInvalid(latestVersionTag) - return nil -} - -func decrementVersionIfInvalid(initialVersion string) string { - // Decrements the version up to 5 times if the version doesn't have valid binaries yet. - version := initialVersion - for i := 0; i < 5; i++ { - updateInfo := buildUpdateInfo(version) - err := assertValidUpdate(updateInfo) - if err == nil { - fmt.Printf("Found a valid version: %v\n", version) - return version - } - fmt.Printf("Found %s to be an invalid version: %v\n", version, err) - version, err = decrementVersion(version) - if err != nil { - fmt.Printf("Failed to decrement version after finding the latest version was invalid: %v\n", err) - return initialVersion - } - } - fmt.Printf("Decremented the version 5 times and failed to find a valid version version number, initial version number: %v, last checked version number: %v\n", initialVersion, version) - return initialVersion -} - -func assertValidUpdate(updateInfo shared.UpdateInfo) error { - urls := []string{updateInfo.LinuxAmd64Url, updateInfo.LinuxAmd64AttestationUrl, updateInfo.LinuxArm64Url, updateInfo.LinuxArm64AttestationUrl, - updateInfo.LinuxArm7Url, updateInfo.LinuxArm7AttestationUrl, - updateInfo.DarwinAmd64Url, updateInfo.DarwinAmd64UnsignedUrl, updateInfo.DarwinAmd64AttestationUrl, - updateInfo.DarwinArm64Url, updateInfo.DarwinArm64UnsignedUrl, updateInfo.DarwinArm64AttestationUrl} - for _, url := range urls { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("failed to retrieve URL %#v: %w", url, err) - } - defer resp.Body.Close() - if resp.StatusCode == 404 { - return fmt.Errorf("URL %#v returned 404", url) - } - } - return nil -} - -func InitDB() { +func InitDB() *database.DB { var err error GLOBAL_DB, err = OpenDB() if err != nil { @@ -228,39 +153,8 @@ func InitDB() { panic(fmt.Errorf("failed to set max idle conns: %w", err)) } } -} -func decrementVersion(version string) (string, error) { - if version == "UNKNOWN" { - return "", fmt.Errorf("cannot decrement UNKNOWN") - } - parts := strings.Split(version, ".") - if len(parts) != 2 { - return "", fmt.Errorf("invalid version: %s", version) - } - versionNumber, err := strconv.Atoi(parts[1]) - if err != nil { - return "", fmt.Errorf("invalid version: %s", version) - } - return parts[0] + "." + strconv.Itoa(versionNumber-1), nil -} - -func buildUpdateInfo(version string) shared.UpdateInfo { - return shared.UpdateInfo{ - LinuxAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64", version), - LinuxAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64.intoto.jsonl", version), - LinuxArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64", version), - LinuxArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64.intoto.jsonl", version), - LinuxArm7Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm", version), - LinuxArm7AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm.intoto.jsonl", version), - DarwinAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64", version), - DarwinAmd64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64-unsigned", version), - DarwinAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64.intoto.jsonl", version), - DarwinArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64", version), - DarwinArm64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64-unsigned", version), - DarwinArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64.intoto.jsonl", version), - Version: version, - } + return GLOBAL_DB } func main() { @@ -275,13 +169,15 @@ func main() { srv := server.NewServer( GLOBAL_DB, server.WithStatsd(s), - server.WithReleaseVersion(ReleaseVersion), + server.WithReleaseVersion(release.Version), server.IsTestEnvironment(isTestEnvironment()), server.IsProductionEnvironment(isProductionEnvironment()), server.WithCron(cron), - server.WithUpdateInfo(buildUpdateInfo(ReleaseVersion)), + server.WithUpdateInfo(release.BuildUpdateInfo(release.Version)), ) + go runBackgroundJobs(context.Background(), srv) + if err := srv.Run(context.Background(), ":8080"); err != nil { panic(err) } diff --git a/internal/release/release.go b/internal/release/release.go new file mode 100644 index 0000000..36b03c3 --- /dev/null +++ b/internal/release/release.go @@ -0,0 +1,127 @@ +package release + +import ( + "encoding/json" + "fmt" + "github.com/ddworken/hishtory/shared" + "io" + "net/http" + "strconv" + "strings" +) + +var Version = "UNKNOWN" + +type releaseInfo struct { + Name string `json:"name"` +} + +const releaseURL = "https://api.github.com/repos/ddworken/hishtory/releases/latest" + +func UpdateReleaseVersion() error { + resp, err := http.Get(releaseURL) + if err != nil { + return fmt.Errorf("failed to get latest release version: %w", err) + } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read github API response body: %w", err) + } + if resp.StatusCode == 403 && strings.Contains(string(respBody), "API rate limit exceeded for ") { + return nil + } + if resp.StatusCode != 200 { + return fmt.Errorf("failed to call github API, status_code=%d, body=%#v", resp.StatusCode, string(respBody)) + } + var info releaseInfo + err = json.Unmarshal(respBody, &info) + if err != nil { + return fmt.Errorf("failed to parse github API response: %w", err) + } + latestVersionTag := info.Name + Version = decrementVersionIfInvalid(latestVersionTag) + return nil +} + +func BuildUpdateInfo(version string) shared.UpdateInfo { + return shared.UpdateInfo{ + LinuxAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64", version), + LinuxAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-amd64.intoto.jsonl", version), + LinuxArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64", version), + LinuxArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm64.intoto.jsonl", version), + LinuxArm7Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm", version), + LinuxArm7AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-linux-arm.intoto.jsonl", version), + DarwinAmd64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64", version), + DarwinAmd64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64-unsigned", version), + DarwinAmd64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-amd64.intoto.jsonl", version), + DarwinArm64Url: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64", version), + DarwinArm64UnsignedUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64-unsigned", version), + DarwinArm64AttestationUrl: fmt.Sprintf("https://github.com/ddworken/hishtory/releases/download/%s/hishtory-darwin-arm64.intoto.jsonl", version), + Version: version, + } +} + +func decrementVersionIfInvalid(initialVersion string) string { + // Decrements the version up to 5 times if the version doesn't have valid binaries yet. + version := initialVersion + for i := 0; i < 5; i++ { + updateInfo := BuildUpdateInfo(version) + err := assertValidUpdate(updateInfo) + if err == nil { + fmt.Printf("Found a valid version: %v\n", version) + return version + } + fmt.Printf("Found %s to be an invalid version: %v\n", version, err) + version, err = decrementVersion(version) + if err != nil { + fmt.Printf("Failed to decrement version after finding the latest version was invalid: %v\n", err) + return initialVersion + } + } + fmt.Printf("Decremented the version 5 times and failed to find a valid version version number, initial version number: %v, last checked version number: %v\n", initialVersion, version) + return initialVersion +} + +func assertValidUpdate(updateInfo shared.UpdateInfo) error { + urls := []string{ + updateInfo.LinuxAmd64Url, + updateInfo.LinuxAmd64AttestationUrl, + updateInfo.LinuxArm64Url, + updateInfo.LinuxArm64AttestationUrl, + updateInfo.LinuxArm7Url, + updateInfo.LinuxArm7AttestationUrl, + updateInfo.DarwinAmd64Url, + updateInfo.DarwinAmd64UnsignedUrl, + updateInfo.DarwinAmd64AttestationUrl, + updateInfo.DarwinArm64Url, + updateInfo.DarwinArm64UnsignedUrl, + updateInfo.DarwinArm64AttestationUrl, + } + for _, url := range urls { + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to retrieve URL %#v: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return fmt.Errorf("URL %#v returned 404", url) + } + } + return nil +} + +func decrementVersion(version string) (string, error) { + if version == "UNKNOWN" { + return "", fmt.Errorf("cannot decrement UNKNOWN") + } + parts := strings.Split(version, ".") + if len(parts) != 2 { + return "", fmt.Errorf("invalid version: %s", version) + } + versionNumber, err := strconv.Atoi(parts[1]) + if err != nil { + return "", fmt.Errorf("invalid version: %s", version) + } + return parts[0] + "." + strconv.Itoa(versionNumber-1), nil +} diff --git a/internal/release/release_test.go b/internal/release/release_test.go new file mode 100644 index 0000000..5cba118 --- /dev/null +++ b/internal/release/release_test.go @@ -0,0 +1,33 @@ +package release + +import ( + "github.com/ddworken/hishtory/shared/testutils" + "strings" + "testing" +) + +func TestUpdateReleaseVersion(t *testing.T) { + if !testutils.IsOnline() { + t.Skip("skipping because we're currently offline") + } + + // Check that ReleaseVersion hasn't been set yet + if Version != "UNKNOWN" { + t.Fatalf("initial ReleaseVersion isn't as expected: %#v", Version) + } + + // Update it + err := UpdateReleaseVersion() + if err != nil { + t.Fatalf("updateReleaseVersion failed: %v", err) + } + + // If ReleaseVersion is still unknown, skip because we're getting rate limited + if Version == "UNKNOWN" { + t.Skip() + } + // Otherwise, check that the new value looks reasonable + if !strings.HasPrefix(Version, "v0.") { + t.Fatalf("ReleaseVersion wasn't updated to contain a version: %#v", Version) + } +} diff --git a/backend/server/server_test.go b/internal/server/server_test.go similarity index 93% rename from backend/server/server_test.go rename to internal/server/server_test.go index 3a8b26b..19dfbf8 100644 --- a/backend/server/server_test.go +++ b/internal/server/server_test.go @@ -1,12 +1,13 @@ -package main +package server import ( "bytes" "context" "encoding/json" + "fmt" "github.com/ddworken/hishtory/internal/database" - "github.com/ddworken/hishtory/internal/server" "github.com/stretchr/testify/require" + "gorm.io/gorm" "io" "net/http" "net/http/httptest" @@ -22,10 +23,32 @@ import ( "github.com/google/uuid" ) +var DB *database.DB + +const testDBDSN = "file::memory:?_journal_mode=WAL&cache=shared" + +func TestMain(m *testing.M) { + // setup test database + db, err := database.OpenSQLite(testDBDSN, &gorm.Config{}) + if err != nil { + panic(fmt.Errorf("failed to connect to the DB: %w", err)) + } + underlyingDb, err := db.DB.DB() + if err != nil { + panic(fmt.Errorf("failed to access underlying DB: %w", err)) + } + underlyingDb.SetMaxOpenConns(1) + db.Exec("PRAGMA journal_mode = WAL") + db.AddDatabaseTables() + + DB = db + + os.Exit(m.Run()) +} + func TestESubmitThenQuery(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Register a few devices userId := data.UserId("key") @@ -120,13 +143,12 @@ func TestESubmitThenQuery(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestDumpRequestAndResponse(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Register a first device for two different users userId := data.UserId("dkey") @@ -288,45 +310,12 @@ func TestDumpRequestAndResponse(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) -} - -func TestUpdateReleaseVersion(t *testing.T) { - if !testutils.IsOnline() { - t.Skip("skipping because we're currently offline") - } - - // Set up - InitDB() - - // Check that ReleaseVersion hasn't been set yet - if ReleaseVersion != "UNKNOWN" { - t.Fatalf("initial ReleaseVersion isn't as expected: %#v", ReleaseVersion) - } - - // Update it - err := updateReleaseVersion() - if err != nil { - t.Fatalf("updateReleaseVersion failed: %v", err) - } - - // If ReleaseVersion is still unknown, skip because we're getting rate limited - if ReleaseVersion == "UNKNOWN" { - t.Skip() - } - // Otherwise, check that the new value looks reasonable - if !strings.HasPrefix(ReleaseVersion, "v0.") { - t.Fatalf("ReleaseVersion wasn't updated to contain a version: %#v", ReleaseVersion) - } - - // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestDeletionRequests(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Register two devices for two different users userId := data.UserId("dkey") @@ -504,11 +493,11 @@ func TestDeletionRequests(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestHealthcheck(t *testing.T) { - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) w := httptest.NewRecorder() s.healthCheckHandler(w, httptest.NewRequest(http.MethodGet, "/", nil)) if w.Code != 200 { @@ -523,15 +512,20 @@ func TestHealthcheck(t *testing.T) { } // Assert that we aren't leaking connections - assertNoLeakedConnections(t, GLOBAL_DB) + assertNoLeakedConnections(t, DB) } func TestLimitRegistrations(t *testing.T) { // Set up - InitDB() - s := server.NewServer(GLOBAL_DB) - checkGormResult(GLOBAL_DB.Exec("DELETE FROM enc_history_entries")) - checkGormResult(GLOBAL_DB.Exec("DELETE FROM devices")) + s := NewServer(DB) + + if resp := DB.Exec("DELETE FROM enc_history_entries"); resp.Error != nil { + t.Fatalf("failed to delete enc_history_entries: %v", resp.Error) + } + + if resp := DB.Exec("DELETE FROM devices"); resp.Error != nil { + t.Fatalf("failed to delete devices: %v", resp.Error) + } defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() os.Setenv("HISHTORY_MAX_NUM_USERS", "2") @@ -552,8 +546,7 @@ func TestLimitRegistrations(t *testing.T) { func TestCleanDatabaseNoErrors(t *testing.T) { // Init - InitDB() - s := server.NewServer(GLOBAL_DB) + s := NewServer(DB) // Create a user and an entry userId := data.UserId("dkey") @@ -570,7 +563,7 @@ func TestCleanDatabaseNoErrors(t *testing.T) { s.apiSubmitHandler(httptest.NewRecorder(), submitReq) // Call cleanDatabase and just check that there are no panics - testutils.Check(t, GLOBAL_DB.Clean(context.TODO())) + testutils.Check(t, DB.Clean(context.TODO())) } func assertNoLeakedConnections(t *testing.T, db *database.DB) { diff --git a/internal/server/srv.go b/internal/server/srv.go index b30ed65..31b9d9b 100644 --- a/internal/server/srv.go +++ b/internal/server/srv.go @@ -125,6 +125,11 @@ func (s *Server) Run(ctx context.Context, addr string) error { return nil } +func (s *Server) UpdateReleaseVersion(v string, updateInfo shared.UpdateInfo) { + s.releaseVersion = v + s.updateInfo = updateInfo +} + func (s *Server) getDeletionRequestsHandler(w http.ResponseWriter, r *http.Request) { userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") diff --git a/shared/testutils/testutils.go b/shared/testutils/testutils.go index 77ec18a..38e5a7f 100644 --- a/shared/testutils/testutils.go +++ b/shared/testutils/testutils.go @@ -228,7 +228,7 @@ func buildServer() string { f, err := os.CreateTemp("", "server") checkError(err) fn := f.Name() - cmd := exec.Command("go", "build", "-o", fn, "-ldflags", fmt.Sprintf("-X main.ReleaseVersion=v0.%s", version), "backend/server/server.go") + cmd := exec.Command("go", "build", "-o", fn, "-ldflags", fmt.Sprintf("-X release.Version=v0.%s", version), "backend/server/server.go") var stdout bytes.Buffer cmd.Stdout = &stdout var stderr bytes.Buffer