From 9ed325e0a5d62f82d2d85c05dfa5d492870e19dd Mon Sep 17 00:00:00 2001 From: David Dworken Date: Sun, 11 Dec 2022 19:42:51 -0800 Subject: [PATCH] Add support for limiting the number of registrations to fix #46 --- README.md | 5 ++++- backend/server/server.go | 24 ++++++++++++++++++++++++ backend/server/server_test.go | 22 ++++++++++++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1159b49..0d6b4f3 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,10 @@ But if you'd like to self-host the hishtory backend, you can! The backend is a s Check out the [`docker-compose.yml`](https://github.com/ddworken/hishtory/blob/master/backend/server/docker-compose.yml) file for an example config to start a hiSHtory server using postgres. -If you want to use a SQLite backend, you can do so by setting the `HISHTORY_SQLITE_DB` environment variable to point to a file. It will then create a SQLite DB at the given location. +A few configuration options: + +* If you want to use a SQLite backend, you can do so by setting the `HISHTORY_SQLITE_DB` environment variable to point to a file. It will then create a SQLite DB at the given location. +* If you want to limit the number of users that your server allows (e.g. because you only intend to use the server for yourself), you can set the environment variable `HISHTORY_MAX_NUM_USERS=1` (or to whatever value you wish for the limit to be). Leave it unset to allow registrations with no cap.
diff --git a/backend/server/server.go b/backend/server/server.go index 9d8a676..8d89373 100644 --- a/backend/server/server.go +++ b/backend/server/server.go @@ -7,6 +7,7 @@ import ( "html" "io" "log" + "math" "net/http" "os" "reflect" @@ -271,6 +272,17 @@ func getRemoteAddr(r *http.Request) string { } func apiRegisterHandler(ctx context.Context, w http.ResponseWriter, r *http.Request) { + if getMaximumNumberOfAllowedUsers() < math.MaxInt { + row := GLOBAL_DB.WithContext(ctx).Raw("SELECT COUNT(DISTINCT devices.user_id) FROM devices").Row() + var numDistinctUsers int64 = 0 + err := row.Scan(&numDistinctUsers) + if err != nil { + panic(err) + } + if numDistinctUsers >= int64(getMaximumNumberOfAllowedUsers()) { + panic(fmt.Sprintf("Refusing to allow registration of new device since there are currently %d users and this server allows a max of %d users", numDistinctUsers, getMaximumNumberOfAllowedUsers())) + } + } userId := getRequiredQueryParam(r, "user_id") deviceId := getRequiredQueryParam(r, "device_id") var existingDevicesCount int64 = -1 @@ -852,4 +864,16 @@ func checkGormResult(result *gorm.DB) { } } +func getMaximumNumberOfAllowedUsers() int { + maxNumUsersStr := os.Getenv("HISHTORY_MAX_NUM_USERS") + if maxNumUsersStr == "" { + return math.MaxInt + } + maxNumUsers, err := strconv.Atoi(maxNumUsersStr) + if err != nil { + return math.MaxInt + } + return maxNumUsers +} + // TODO(optimization): Maybe optimize the endpoints a bit to reduce the number of round trips required? diff --git a/backend/server/server_test.go b/backend/server/server_test.go index ac89dcc..d20a30b 100644 --- a/backend/server/server_test.go +++ b/backend/server/server_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "os" "strings" "testing" "time" @@ -521,6 +522,27 @@ func TestHealthcheck(t *testing.T) { assertNoLeakedConnections(t, GLOBAL_DB) } +func TestLimitRegistrations(t *testing.T) { + // Set up + InitDB() + defer testutils.BackupAndRestoreEnv("HISHTORY_MAX_NUM_USERS")() + os.Setenv("HISHTORY_MAX_NUM_USERS", "2") + + // Register three devices across two users + deviceReq := httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) + apiRegisterHandler(context.Background(), nil, deviceReq) + deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user1"), nil) + apiRegisterHandler(context.Background(), nil, deviceReq) + deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user2"), nil) + apiRegisterHandler(context.Background(), nil, deviceReq) + + // And this next one should fail since it is a new user + defer func() { _ = recover() }() + deviceReq = httptest.NewRequest(http.MethodGet, "/?device_id="+uuid.Must(uuid.NewRandom()).String()+"&user_id="+data.UserId("user3"), nil) + apiRegisterHandler(context.Background(), nil, deviceReq) + t.Errorf("expected panic") +} + func assertNoLeakedConnections(t *testing.T, db *gorm.DB) { sqlDB, err := db.DB() if err != nil {