diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go index 6061676c5..30510bb8e 100644 --- a/internal/db/bundb/account.go +++ b/internal/db/bundb/account.go @@ -25,11 +25,13 @@ "strings" "time" + "github.com/sirupsen/logrus" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" ) type accountDB struct { @@ -268,14 +270,24 @@ func (a *accountDB) GetAccountStatuses(ctx context.Context, accountID string, li if mediaOnly { // attachments are stored as a json object; // this implementation differs between sqlite and postgres, - // so we have to be very thorough to cover all eventualities + // so we have to be thorough to cover all eventualities q = q.WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { - return q. - Where("? IS NOT NULL", bun.Ident("attachments")). - Where("? != ''", bun.Ident("attachments")). - Where("? != 'null'", bun.Ident("attachments")). - Where("? != '{}'", bun.Ident("attachments")). - Where("? != '[]'", bun.Ident("attachments")) + switch a.conn.Dialect().Name() { + case dialect.PG: + return q. + Where("? IS NOT NULL", bun.Ident("attachments")). + Where("? != '{}'", bun.Ident("attachments")) + case dialect.SQLite: + return q. + Where("? IS NOT NULL", bun.Ident("attachments")). + Where("? != ''", bun.Ident("attachments")). + Where("? != 'null'", bun.Ident("attachments")). + Where("? != '{}'", bun.Ident("attachments")). + Where("? != '[]'", bun.Ident("attachments")) + default: + logrus.Panic("db dialect was neither pg nor sqlite") + return q + } }) } diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go index 33c6b7ee6..134e38940 100644 --- a/internal/db/bundb/account_test.go +++ b/internal/db/bundb/account_test.go @@ -40,6 +40,12 @@ func (suite *AccountTestSuite) TestGetAccountStatuses() { suite.Len(statuses, 5) } +func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() { + statuses, err := suite.db.GetAccountStatuses(context.Background(), suite.testAccounts["local_account_1"].ID, 20, false, false, "", "", false, true, false) + suite.NoError(err) + suite.Len(statuses, 1) +} + func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() { account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID) if err != nil { diff --git a/internal/db/bundb/bundbnew_test.go b/internal/db/bundb/bundbnew_test.go index d5e413a4f..2bd945864 100644 --- a/internal/db/bundb/bundbnew_test.go +++ b/internal/db/bundb/bundbnew_test.go @@ -41,6 +41,7 @@ func (suite *BundbNewTestSuite) TestCreateNewDB() { func (suite *BundbNewTestSuite) TestCreateNewSqliteDBNoAddress() { // create a new db with no address specified config.SetDbAddress("") + config.SetDbType("sqlite") db, err := bundb.NewBunDBService(context.Background()) suite.EqualError(err, "'db-address' was not set when attempting to start sqlite") suite.Nil(db)