diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index efedda9ec..f0a5e1d5d 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -30,6 +30,7 @@ "time" "github.com/KimMachineGun/automemlimit/memlimit" + webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/internal/admin" @@ -40,6 +41,7 @@ "github.com/superseriousbusiness/gotosocial/internal/filter/spam" "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media/ffmpeg" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/metrics" @@ -248,6 +250,22 @@ } } + // Get or create a VAPID key pair. + vapidKeyPair, err := dbService.GetVAPIDKeyPair(ctx) + if err != nil { + return gtserror.Newf("error getting VAPID key pair: %w", err) + } + if vapidKeyPair == nil { + // Generate and store a new key pair. + vapidKeyPair = >smodel.VAPIDKeyPair{} + if vapidKeyPair.Private, vapidKeyPair.Public, err = webpushgo.GenerateVAPIDKeys(); err != nil { + return gtserror.Newf("error generating VAPID key pair: %w", err) + } + if err := dbService.PutVAPIDKeyPair(ctx, vapidKeyPair); err != nil { + return gtserror.Newf("error putting VAPID key pair: %w", err) + } + } + // Initialize both home / list timelines. state.Timelines.Home = timeline.NewManager( tlprocessor.HomeTimelineGrab(state), diff --git a/internal/db/admin.go b/internal/db/admin.go index d0da54e31..77fbbe613 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -68,9 +68,13 @@ type Admin interface { // the number of pending sign-ups sitting in the backlog. CountUnhandledSignups(ctx context.Context) (int, error) - // GetOrCreateVAPIDKeyPair creates and stores a VAPID key pair, - // or retrieves the existing VAPID key pair. - GetOrCreateVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) + // GetVAPIDKeyPair retrieves the existing VAPID key pair, if there is one. + // If there isn't, it returns nil. + GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) + + // PutVAPIDKeyPair stores a VAPID key pair. + // This should be called at most once, during server startup. + PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error /* ACTION FUNCS diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index dacb2cb1f..266b351f5 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -27,7 +27,6 @@ "strings" "time" - webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/google/uuid" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/config" @@ -49,6 +48,9 @@ type adminDB struct { db *bun.DB state *state.State + + // Since the VAPID key pair is very small and never written to concurrently, we can cache it here. + vapidKeyPair *gtsmodel.VAPIDKeyPair } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) { @@ -443,36 +445,37 @@ func (a *adminDB) CountUnhandledSignups(ctx context.Context) (int, error) { Count(ctx) } -func (a *adminDB) GetOrCreateVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) { - var err error - var vapidKeyPair *gtsmodel.VAPIDKeyPair +func (a *adminDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) { + // Look for cached keys. + if a.vapidKeyPair != nil { + return a.vapidKeyPair, nil + } - // Look for previously generated keys. - if err = a.db.NewSelect(). - Model(vapidKeyPair). + // Look for previously generated keys in the database. + if err := a.db.NewSelect(). + Model(a.vapidKeyPair). Limit(1). Scan(ctx); // nocollapse err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.Newf("DB error getting VAPID key pair: %w", err) } - if vapidKeyPair == nil { - // Generate new keys. - vapidKeyPair = >smodel.VAPIDKeyPair{} - if vapidKeyPair.Private, vapidKeyPair.Public, err = webpushgo.GenerateVAPIDKeys(); err != nil { - return nil, gtserror.Newf("error generating VAPID key pair: %w", err) - } + return a.vapidKeyPair, nil +} - // Save them to the database. - if _, err = a.db.NewInsert(). - Model(vapidKeyPair). - Exec(ctx); // nocollapse - err != nil { - return nil, gtserror.Newf("DB error saving VAPID key pair: %w", err) - } +func (a *adminDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error { + // Store the keys in the database. + if _, err := a.db.NewInsert(). + Model(a.vapidKeyPair). + Exec(ctx); // nocollapse + err != nil { + return gtserror.Newf("DB error putting VAPID key pair: %w", err) } - return vapidKeyPair, err + // Cache the keys. + a.vapidKeyPair = vapidKeyPair + + return nil } /* diff --git a/testrig/db.go b/testrig/db.go index e53e9c9f0..52fe7b822 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -19,7 +19,7 @@ import ( "context" - + webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -368,6 +368,15 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { log.Panic(ctx, err) } + vapidKeyPair := >smodel.VAPIDKeyPair{} + var err error + if vapidKeyPair.Private, vapidKeyPair.Public, err = webpushgo.GenerateVAPIDKeys(); err != nil { + log.Panic(nil, err) + } + if err = db.PutVAPIDKeyPair(ctx, vapidKeyPair); err != nil { + log.Panic(nil, err) + } + log.Debug(ctx, "testing db setup complete") }