Refactor server.go to remove two global variables

This commit is contained in:
David Dworken 2023-09-13 21:47:06 -07:00
parent b478eadeae
commit a66ea1387d
No known key found for this signature in database
3 changed files with 31 additions and 44 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"runtime"
"time" "time"
"github.com/DataDog/datadog-go/statsd" "github.com/DataDog/datadog-go/statsd"
@ -23,8 +22,7 @@ const (
) )
var ( var (
GLOBAL_DB *database.DB // Filled in via ldflags with the latest released version as of the server getting built
GLOBAL_STATSD *statsd.Client
ReleaseVersion string ReleaseVersion string
) )
@ -96,14 +94,6 @@ func OpenDB() (*database.DB, error) {
return db, nil return db, nil
} }
func init() {
release.Version = ReleaseVersion
if release.Version == "UNKNOWN" && !isTestEnvironment() {
panic("server.go was built without a ReleaseVersion!")
}
InitDB()
}
func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error { func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error {
if err := release.UpdateReleaseVersion(); err != nil { if err := release.UpdateReleaseVersion(); err != nil {
return fmt.Errorf("updateReleaseVersion: %w", err) return fmt.Errorf("updateReleaseVersion: %w", err)
@ -120,10 +110,10 @@ func cron(ctx context.Context, db *database.DB, stats *statsd.Client) error {
return nil return nil
} }
func runBackgroundJobs(ctx context.Context, srv *server.Server) { func runBackgroundJobs(ctx context.Context, srv *server.Server, db *database.DB, stats *statsd.Client) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
for { for {
err := cron(ctx, GLOBAL_DB, GLOBAL_STATSD) err := cron(ctx, db, stats)
if err != nil { if err != nil {
fmt.Printf("Cron failure: %v", err) fmt.Printf("Cron failure: %v", err)
@ -136,41 +126,45 @@ func runBackgroundJobs(ctx context.Context, srv *server.Server) {
} }
func InitDB() *database.DB { func InitDB() *database.DB {
var err error db, err := OpenDB()
GLOBAL_DB, err = OpenDB()
if err != nil { if err != nil {
panic(fmt.Errorf("OpenDB: %w", err)) panic(fmt.Errorf("OpenDB: %w", err))
} }
if err := GLOBAL_DB.Ping(); err != nil { if err := db.Ping(); err != nil {
panic(fmt.Errorf("ping: %w", err)) panic(fmt.Errorf("ping: %w", err))
} }
if isProductionEnvironment() { if isProductionEnvironment() {
if err := GLOBAL_DB.SetMaxIdleConns(10); err != nil { if err := db.SetMaxIdleConns(10); err != nil {
panic(fmt.Errorf("failed to set max idle conns: %w", err)) panic(fmt.Errorf("failed to set max idle conns: %w", err))
} }
} }
if isTestEnvironment() { if isTestEnvironment() {
if err := GLOBAL_DB.SetMaxIdleConns(1); err != nil { if err := db.SetMaxIdleConns(1); err != nil {
panic(fmt.Errorf("failed to set max idle conns: %w", err)) panic(fmt.Errorf("failed to set max idle conns: %w", err))
} }
} }
return GLOBAL_DB return db
} }
func main() { func main() {
s, err := statsd.New(StatsdSocket) // Startup check:
release.Version = ReleaseVersion
if release.Version == "UNKNOWN" && !isTestEnvironment() {
panic("server.go was built without a ReleaseVersion!")
}
// Create DB and stats
db := InitDB()
stats, err := statsd.New(StatsdSocket)
if err != nil { if err != nil {
fmt.Printf("Failed to start DataDog statsd: %v\n", err) fmt.Printf("Failed to start DataDog statsd: %v\n", err)
} }
// TODO: remove this global once we have a better way to pass it around
GLOBAL_STATSD = s
srv := server.NewServer( srv := server.NewServer(
GLOBAL_DB, db,
server.WithStatsd(s), server.WithStatsd(stats),
server.WithReleaseVersion(release.Version), server.WithReleaseVersion(release.Version),
server.IsTestEnvironment(isTestEnvironment()), server.IsTestEnvironment(isTestEnvironment()),
server.IsProductionEnvironment(isProductionEnvironment()), server.IsProductionEnvironment(isProductionEnvironment()),
@ -178,25 +172,12 @@ func main() {
server.WithUpdateInfo(release.BuildUpdateInfo(release.Version)), server.WithUpdateInfo(release.BuildUpdateInfo(release.Version)),
) )
go runBackgroundJobs(context.Background(), srv) go runBackgroundJobs(context.Background(), srv, db, stats)
if err := srv.Run(context.Background(), ":8080"); err != nil { if err := srv.Run(context.Background(), ":8080"); err != nil {
panic(err) panic(err)
} }
} }
func checkGormResult(result *gorm.DB) {
checkGormError(result.Error, 1)
}
func checkGormError(err error, skip int) {
if err == nil {
return
}
_, filename, line, _ := runtime.Caller(skip + 1)
panic(fmt.Sprintf("DB error at %s:%d: %v", filename, line, err))
}
// TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required? // TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required?
// TODO: Add error checking for the calls to updateUsageData(...) that logs it/triggers an alert in prod, but is an error in test // TODO: Add error checking for the calls to updateUsageData(...) that logs it/triggers an alert in prod, but is an error in test

View File

@ -3,13 +3,15 @@ package release
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ddworken/hishtory/shared"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"github.com/ddworken/hishtory/shared"
) )
// TODO: Can we get rid of this bit of mutable state by changing UpdateReleaseVersion to return the latest version?
var Version = "UNKNOWN" var Version = "UNKNOWN"
type releaseInfo struct { type releaseInfo struct {

View File

@ -5,9 +5,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ddworken/hishtory/internal/database"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -16,6 +13,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/ddworken/hishtory/internal/database"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/ddworken/hishtory/client/data" "github.com/ddworken/hishtory/client/data"
"github.com/ddworken/hishtory/shared" "github.com/ddworken/hishtory/shared"
"github.com/ddworken/hishtory/shared/testutils" "github.com/ddworken/hishtory/shared/testutils"
@ -39,7 +40,10 @@ func TestMain(m *testing.M) {
} }
underlyingDb.SetMaxOpenConns(1) underlyingDb.SetMaxOpenConns(1)
db.Exec("PRAGMA journal_mode = WAL") db.Exec("PRAGMA journal_mode = WAL")
db.AddDatabaseTables() err = db.AddDatabaseTables()
if err != nil {
panic(fmt.Errorf("failed to add database tables: %w", err))
}
DB = db DB = db