diff --git a/backend/server/internal/database/db.go b/backend/server/internal/database/db.go index dcf87ff..308836c 100644 --- a/backend/server/internal/database/db.go +++ b/backend/server/internal/database/db.go @@ -153,6 +153,19 @@ func (db *DB) DistinctUsers(ctx context.Context) (int64, error) { return numDistinctUsers, nil } +func (db *DB) UserAlreadyExist(ctx context.Context, userID string) (bool, error) { + var cnt int64 + tx := db.WithContext(ctx).Table("devices").Where("user_id = ?", userID).Count(&cnt) + if tx.Error != nil { + return false, fmt.Errorf("tx.Error: %w", tx.Error) + } + + if cnt > 0 { + return true, nil + } + return false, nil +} + func (db *DB) DumpRequestCreate(ctx context.Context, req *shared.DumpRequest) error { tx := db.WithContext(ctx).Create(req) if tx.Error != nil { @@ -410,7 +423,7 @@ func (db *DB) DeepClean(ctx context.Context) error { FROM devices GROUP BY user_id HAVING COUNT(DISTINCT device_id) = 1 - ) + ) `) if r.Error != nil { return fmt.Errorf("failed to create list of single device users: %w", r.Error) @@ -420,7 +433,7 @@ func (db *DB) DeepClean(ctx context.Context) error { SELECT user_id FROM usage_data WHERE last_used <= (now() - INTERVAL '180 days') - ) + ) `) if r.Error != nil { return fmt.Errorf("failed to create list of inactive users: %w", r.Error) diff --git a/backend/server/internal/server/api_handlers.go b/backend/server/internal/server/api_handlers.go index 0314495..22199df 100644 --- a/backend/server/internal/server/api_handlers.go +++ b/backend/server/internal/server/api_handlers.go @@ -201,19 +201,27 @@ func (s *Server) apiDownloadHandler(w http.ResponseWriter, r *http.Request) { } func (s *Server) apiRegisterHandler(w http.ResponseWriter, r *http.Request) { - if getMaximumNumberOfAllowedUsers() < math.MaxInt { - numDistinctUsers, err := s.db.DistinctUsers(r.Context()) - if err != nil { - panic(fmt.Errorf("db.DistinctUsers: %w", err)) - } - if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { - panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) - } - } userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") isIntegrationTestDevice := getOptionalQueryParam(r, "is_integration_test_device", false) == "true" + if getMaximumNumberOfAllowedUsers() < math.MaxInt { + userAlreadyExist, err := s.db.UserAlreadyExist(r.Context(), userId) + if err != nil { + panic(fmt.Errorf("db.UserAlreadyExist: %w", err)) + } + + if !userAlreadyExist { + numDistinctUsers, err := s.db.DistinctUsers(r.Context()) + if err != nil { + panic(fmt.Errorf("db.DistinctUsers: %w", err)) + } + if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { + panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) + } + } + } + existingDevicesCount, err := s.db.CountDevicesForUser(r.Context(), userId) checkGormError(err) fmt.Printf("apiRegisterHandler: existingDevicesCount=%d\n", existingDevicesCount)