mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-12 17:08:36 +01:00
[management] Add GCM encryption and migrate legacy encrypted events (#2569)
* Add AES-GCM encryption Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * migrate legacy encrypted data to AES-GCM encryption Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor and use transaction when migrating data Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Add events migration tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix lint Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * skip migrating record on error Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Preallocate capacity for nonce to avoid allocations in Seal Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
c59a39d27d
commit
cf6210a6f4
@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,6 +14,7 @@ var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
|||||||
|
|
||||||
type FieldEncrypt struct {
|
type FieldEncrypt struct {
|
||||||
block cipher.Block
|
block cipher.Block
|
||||||
|
gcm cipher.AEAD
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateKey() (string, error) {
|
func GenerateKey() (string, error) {
|
||||||
@ -35,14 +37,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
ec := &FieldEncrypt{
|
ec := &FieldEncrypt{
|
||||||
block: block,
|
block: block,
|
||||||
|
gcm: gcm,
|
||||||
}
|
}
|
||||||
|
|
||||||
return ec, nil
|
return ec, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ec *FieldEncrypt) Encrypt(payload string) string {
|
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
|
||||||
plainText := pkcs5Padding([]byte(payload))
|
plainText := pkcs5Padding([]byte(payload))
|
||||||
cipherText := make([]byte, len(plainText))
|
cipherText := make([]byte, len(plainText))
|
||||||
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
||||||
@ -50,7 +59,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
|
|||||||
return base64.StdEncoding.EncodeToString(cipherText)
|
return base64.StdEncoding.EncodeToString(cipherText)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
// Encrypt encrypts plaintext using AES-GCM
|
||||||
|
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
|
||||||
|
plaintext := []byte(payload)
|
||||||
|
nonceSize := ec.gcm.NonceSize()
|
||||||
|
|
||||||
|
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
|
||||||
|
|
||||||
|
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
|
||||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -65,6 +89,27 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
|||||||
return string(payload), nil
|
return string(payload), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts ciphertext using AES-GCM
|
||||||
|
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||||
|
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := ec.gcm.NonceSize()
|
||||||
|
if len(cipherText) < nonceSize {
|
||||||
|
return "", errors.New("cipher text too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||||
|
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plainText), nil
|
||||||
|
}
|
||||||
|
|
||||||
func pkcs5Padding(ciphertext []byte) []byte {
|
func pkcs5Padding(ciphertext []byte) []byte {
|
||||||
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
||||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||||
|
@ -15,7 +15,11 @@ func TestGenerateKey(t *testing.T) {
|
|||||||
t.Fatalf("failed to init email encryption: %s", err)
|
t.Fatalf("failed to init email encryption: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
encrypted := ee.Encrypt(testData)
|
encrypted, err := ee.Encrypt(testData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to encrypt data: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if encrypted == "" {
|
if encrypted == "" {
|
||||||
t.Fatalf("invalid encrypted text")
|
t.Fatalf("invalid encrypted text")
|
||||||
}
|
}
|
||||||
@ -30,6 +34,32 @@ func TestGenerateKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateKeyLegacy(t *testing.T) {
|
||||||
|
testData := "exampl@netbird.io"
|
||||||
|
key, err := GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate key: %s", err)
|
||||||
|
}
|
||||||
|
ee, err := NewFieldEncrypt(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to init email encryption: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := ee.LegacyEncrypt(testData)
|
||||||
|
if encrypted == "" {
|
||||||
|
t.Fatalf("invalid encrypted text")
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, err := ee.LegacyDecrypt(encrypted)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to decrypt data: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decrypted != testData {
|
||||||
|
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCorruptKey(t *testing.T) {
|
func TestCorruptKey(t *testing.T) {
|
||||||
testData := "exampl@netbird.io"
|
testData := "exampl@netbird.io"
|
||||||
key, err := GenerateKey()
|
key, err := GenerateKey()
|
||||||
@ -41,7 +71,11 @@ func TestCorruptKey(t *testing.T) {
|
|||||||
t.Fatalf("failed to init email encryption: %s", err)
|
t.Fatalf("failed to init email encryption: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
encrypted := ee.Encrypt(testData)
|
encrypted, err := ee.Encrypt(testData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to encrypt data: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if encrypted == "" {
|
if encrypted == "" {
|
||||||
t.Fatalf("invalid encrypted text")
|
t.Fatalf("invalid encrypted text")
|
||||||
}
|
}
|
||||||
|
157
management/server/activity/sqlite/migration.go
Normal file
157
management/server/activity/sqlite/migration.go
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
|
||||||
|
if _, err := db.Exec(createTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := updateDeletedUsersTable(ctx, db); err != nil {
|
||||||
|
return fmt.Errorf("failed to update deleted_users table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
|
||||||
|
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
||||||
|
exists, err := checkColumnExists(db, "deleted_users", "name")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
|
||||||
|
|
||||||
|
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
|
||||||
|
|
||||||
|
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
|
||||||
|
// legacy CBC encryption with a static IV to the new GCM encryption method.
|
||||||
|
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
|
||||||
|
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to begin transaction: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
|
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute select query: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to prepare update statement: %v", err)
|
||||||
|
}
|
||||||
|
defer updateStmt.Close()
|
||||||
|
|
||||||
|
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("failed to commit transaction: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
|
||||||
|
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
id, decryptedEmail, decryptedName string
|
||||||
|
email, name *string
|
||||||
|
)
|
||||||
|
|
||||||
|
err := rows.Scan(&id, &email, &name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if email != nil {
|
||||||
|
decryptedEmail, err = crypt.LegacyDecrypt(*email)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
|
||||||
|
id,
|
||||||
|
fmt.Errorf("failed to decrypt email: %w", err),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if name != nil {
|
||||||
|
decryptedName, err = crypt.LegacyDecrypt(*name)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
|
||||||
|
id,
|
||||||
|
fmt.Errorf("failed to decrypt name: %w", err),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt email: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedName, err := crypt.Encrypt(decryptedName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encrypt name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
84
management/server/activity/sqlite/migration_test.go
Normal file
84
management/server/activity/sqlite/migration_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupDatabase(t *testing.T) *sql.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
|
||||||
|
db, err := sql.Open("sqlite3", dbFile)
|
||||||
|
require.NoError(t, err, "Failed to open database")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = db.Exec(createTableQuery)
|
||||||
|
require.NoError(t, err, "Failed to create events table")
|
||||||
|
|
||||||
|
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
|
||||||
|
require.NoError(t, err, "Failed to create deleted_users table")
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrate(t *testing.T) {
|
||||||
|
db := setupDatabase(t)
|
||||||
|
|
||||||
|
key, err := GenerateKey()
|
||||||
|
require.NoError(t, err, "Failed to generate key")
|
||||||
|
|
||||||
|
crypt, err := NewFieldEncrypt(key)
|
||||||
|
require.NoError(t, err, "Failed to initialize FieldEncrypt")
|
||||||
|
|
||||||
|
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
|
||||||
|
legacyName := crypt.LegacyEncrypt("Test Account")
|
||||||
|
|
||||||
|
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
|
||||||
|
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
|
||||||
|
require.NoError(t, err, "Failed to insert event")
|
||||||
|
|
||||||
|
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
|
||||||
|
require.NoError(t, err, "Failed to insert legacy encrypted data")
|
||||||
|
|
||||||
|
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
|
||||||
|
require.NoError(t, err, "Failed to check if enc_algo column exists")
|
||||||
|
require.False(t, colExists, "enc_algo column should not exist before migration")
|
||||||
|
|
||||||
|
err = migrate(context.Background(), crypt, db)
|
||||||
|
require.NoError(t, err, "Migration failed")
|
||||||
|
|
||||||
|
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
|
||||||
|
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
|
||||||
|
require.True(t, colExists, "enc_algo column should exist after migration")
|
||||||
|
|
||||||
|
var encAlgo string
|
||||||
|
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
|
||||||
|
require.NoError(t, err, "Failed to select updated data")
|
||||||
|
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
|
||||||
|
|
||||||
|
store, err := createStore(crypt, db)
|
||||||
|
require.NoError(t, err, "Failed to create store")
|
||||||
|
|
||||||
|
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
|
||||||
|
require.NoError(t, err, "Failed to get events")
|
||||||
|
|
||||||
|
require.Len(t, events, 1, "Should have one event")
|
||||||
|
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
|
||||||
|
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
|
||||||
|
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
|
||||||
|
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
|
||||||
|
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
|
||||||
|
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
|
||||||
|
}
|
@ -26,7 +26,7 @@ const (
|
|||||||
"meta TEXT," +
|
"meta TEXT," +
|
||||||
" target_id TEXT);"
|
" target_id TEXT);"
|
||||||
|
|
||||||
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
|
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
|
||||||
|
|
||||||
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
|
||||||
FROM events
|
FROM events
|
||||||
@ -69,10 +69,12 @@ const (
|
|||||||
and some selfhosted deployments might have duplicates already so we need to clean the table first.
|
and some selfhosted deployments might have duplicates already so we need to clean the table first.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
|
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
|
||||||
|
|
||||||
fallbackName = "unknown"
|
fallbackName = "unknown"
|
||||||
fallbackEmail = "unknown@unknown.com"
|
fallbackEmail = "unknown@unknown.com"
|
||||||
|
|
||||||
|
gcmEncAlgo = "GCM"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||||
@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(createTableQuery)
|
if err = migrate(ctx, crypt, db); err != nil {
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
_ = db.Close()
|
||||||
return nil, err
|
return nil, fmt.Errorf("events database migration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(creatTableDeletedUsersQuery)
|
return createStore(crypt, db)
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = updateDeletedUsersTable(ctx, db)
|
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
insertStmt, err := db.Prepare(insertQuery)
|
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
selectDescStmt, err := db.Prepare(selectDescQuery)
|
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
selectAscStmt, err := db.Prepare(selectAscQuery)
|
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
|
||||||
if err != nil {
|
|
||||||
_ = db.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Store{
|
|
||||||
db: db,
|
|
||||||
fieldEncrypt: crypt,
|
|
||||||
insertStatement: insertStmt,
|
|
||||||
selectDescStatement: selectDescStmt,
|
|
||||||
selectAscStatement: selectAscStmt,
|
|
||||||
deleteUserStmt: deleteUserStmt,
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
|
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
|
||||||
@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
|
|||||||
return event.Meta, nil
|
return event.Meta, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||||
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
if err != nil {
|
||||||
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
|
return nil, err
|
||||||
|
}
|
||||||
|
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
// createStore initializes and returns a new Store instance with prepared SQL statements.
|
||||||
log.WithContext(ctx).Debugf("check deleted_users table version")
|
func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
|
||||||
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
|
insertStmt, err := db.Prepare(insertQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
selectDescStmt, err := db.Prepare(selectDescQuery)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
selectAscStmt, err := db.Prepare(selectAscQuery)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Store{
|
||||||
|
db: db,
|
||||||
|
fieldEncrypt: crypt,
|
||||||
|
insertStatement: insertStmt,
|
||||||
|
selectDescStatement: selectDescStmt,
|
||||||
|
selectAscStatement: selectAscStmt,
|
||||||
|
deleteUserStmt: deleteUserStmt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkColumnExists checks if a column exists in a specified table
|
||||||
|
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
|
||||||
|
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to query table info: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
found := false
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var (
|
var cid int
|
||||||
cid int
|
var name, ctype string
|
||||||
name string
|
var notnull, pk int
|
||||||
dataType string
|
var dfltValue sql.NullString
|
||||||
notNull int
|
|
||||||
dfltVal sql.NullString
|
err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk)
|
||||||
pk int
|
|
||||||
)
|
|
||||||
err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, fmt.Errorf("failed to scan row: %w", err)
|
||||||
}
|
}
|
||||||
if name == "name" {
|
|
||||||
found = true
|
if name == columnName {
|
||||||
break
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rows.Err()
|
if err = rows.Err(); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if found {
|
return false, nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("update delted_users table")
|
|
||||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user