[management] Add GCM encryption and migrate legacy encrypted events (#2569)

* Add AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* migrate legacy encrypted data to AES-GCM encryption

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor and use transaction when migrating data

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add events migration tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip migrating record on error

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preallocate capacity for nonce to avoid allocations in Seal

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Bethuel Mmbaga 2024-09-11 20:09:57 +03:00 committed by GitHub
parent c59a39d27d
commit cf6210a6f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 396 additions and 86 deletions

View File

@ -6,6 +6,7 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
) )
@ -13,6 +14,7 @@ var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct { type FieldEncrypt struct {
block cipher.Block block cipher.Block
gcm cipher.AEAD
} }
func GenerateKey() (string, error) { func GenerateKey() (string, error) {
@ -35,14 +37,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{ ec := &FieldEncrypt{
block: block, block: block,
gcm: gcm,
} }
return ec, nil return ec, nil
} }
func (ec *FieldEncrypt) Encrypt(payload string) string { func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload)) plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText)) cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv) cbc := cipher.NewCBCEncrypter(ec.block, iv)
@ -50,7 +59,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText) return base64.StdEncoding.EncodeToString(cipherText)
} }
func (ec *FieldEncrypt) Decrypt(data string) (string, error) { // Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data) cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil { if err != nil {
return "", err return "", err
@ -65,6 +89,27 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil return string(payload), nil
} }
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte { func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding) padText := bytes.Repeat([]byte{byte(padding)}, padding)

View File

@ -15,7 +15,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
encrypted := ee.Encrypt(testData) encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" { if encrypted == "" {
t.Fatalf("invalid encrypted text") t.Fatalf("invalid encrypted text")
} }
@ -30,6 +34,32 @@ func TestGenerateKey(t *testing.T) {
} }
} }
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) { func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io" testData := "exampl@netbird.io"
key, err := GenerateKey() key, err := GenerateKey()
@ -41,7 +71,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err) t.Fatalf("failed to init email encryption: %s", err)
} }
encrypted := ee.Encrypt(testData) encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" { if encrypted == "" {
t.Fatalf("invalid encrypted text") t.Fatalf("invalid encrypted text")
} }

View File

@ -0,0 +1,157 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
if _, err := db.Exec(createTableQuery); err != nil {
return err
}
if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
return err
}
if err := updateDeletedUsersTable(ctx, db); err != nil {
return fmt.Errorf("failed to update deleted_users table: %v", err)
}
return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
}
// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
exists, err := checkColumnExists(db, "deleted_users", "name")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
}
exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
if err != nil {
return err
}
if !exists {
log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
if err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
}
return nil
}
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
defer func() {
_ = tx.Rollback()
}()
rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
if err != nil {
return fmt.Errorf("failed to execute select query: %v", err)
}
defer rows.Close()
updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
if err != nil {
return fmt.Errorf("failed to prepare update statement: %v", err)
}
defer updateStmt.Close()
if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("failed to commit transaction: %v", err)
}
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
return nil
}
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
for rows.Next() {
var (
id, decryptedEmail, decryptedName string
email, name *string
)
err := rows.Scan(&id, &email, &name)
if err != nil {
return err
}
if email != nil {
decryptedEmail, err = crypt.LegacyDecrypt(*email)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt email: %w", err),
)
continue
}
}
if name != nil {
decryptedName, err = crypt.LegacyDecrypt(*name)
if err != nil {
log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
id,
fmt.Errorf("failed to decrypt name: %w", err),
)
continue
}
}
encryptedEmail, err := crypt.Encrypt(decryptedEmail)
if err != nil {
return fmt.Errorf("failed to encrypt email: %w", err)
}
encryptedName, err := crypt.Encrypt(decryptedName)
if err != nil {
return fmt.Errorf("failed to encrypt name: %w", err)
}
_, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
if err != nil {
return err
}
}
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,84 @@
package sqlite
import (
"context"
"database/sql"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/require"
)
func setupDatabase(t *testing.T) *sql.DB {
t.Helper()
dbFile := filepath.Join(t.TempDir(), eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
require.NoError(t, err, "Failed to open database")
t.Cleanup(func() {
_ = db.Close()
})
_, err = db.Exec(createTableQuery)
require.NoError(t, err, "Failed to create events table")
_, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
require.NoError(t, err, "Failed to create deleted_users table")
return db
}
func TestMigrate(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
legacyName := crypt.LegacyEncrypt("Test Account")
_, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
require.NoError(t, err, "Failed to insert event")
_, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
require.NoError(t, err, "Failed to insert legacy encrypted data")
colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists")
require.False(t, colExists, "enc_algo column should not exist before migration")
err = migrate(context.Background(), crypt, db)
require.NoError(t, err, "Migration failed")
colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
require.True(t, colExists, "enc_algo column should exist after migration")
var encAlgo string
err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
require.NoError(t, err, "Failed to select updated data")
require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
store, err := createStore(crypt, db)
require.NoError(t, err, "Failed to create store")
events, err := store.Get(context.Background(), "accountID", 0, 1, false)
require.NoError(t, err, "Failed to get events")
require.Len(t, events, 1, "Should have one event")
require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
require.Equal(t, "targetID", events[0].TargetID, "target id should match")
require.Equal(t, "accountID", events[0].AccountID, "account id should match")
require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
}

View File

@ -26,7 +26,7 @@ const (
"meta TEXT," + "meta TEXT," +
" target_id TEXT);" " target_id TEXT);"
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);` creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events FROM events
@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first. and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/ */
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown" fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com" fallbackEmail = "unknown@unknown.com"
gcmEncAlgo = "GCM"
) )
// Store is the implementation of the activity.Store interface backed by SQLite // Store is the implementation of the activity.Store interface backed by SQLite
@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err return nil, err
} }
_, err = db.Exec(createTableQuery) if err = migrate(ctx, crypt, db); err != nil {
if err != nil {
_ = db.Close() _ = db.Close()
return nil, err return nil, fmt.Errorf("events database migration: %w", err)
} }
_, err = db.Exec(creatTableDeletedUsersQuery) return createStore(crypt, db)
if err != nil {
_ = db.Close()
return nil, err
}
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
}
insertStmt, err := db.Prepare(insertQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
s := &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}
return s, nil
} }
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) { func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil return event.Meta, nil
} }
encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) if err != nil {
_, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName) return nil, err
}
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
if err != nil {
return nil, err
}
_, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil return nil
} }
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { // createStore initializes and returns a new Store instance with prepared SQL statements.
log.WithContext(ctx).Debugf("check deleted_users table version") func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
rows, err := db.Query(`PRAGMA table_info(deleted_users);`) insertStmt, err := db.Prepare(insertQuery)
if err != nil { if err != nil {
return err _ = db.Close()
return nil, err
}
selectDescStmt, err := db.Prepare(selectDescQuery)
if err != nil {
_ = db.Close()
return nil, err
}
selectAscStmt, err := db.Prepare(selectAscQuery)
if err != nil {
_ = db.Close()
return nil, err
}
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
if err != nil {
_ = db.Close()
return nil, err
}
return &Store{
db: db,
fieldEncrypt: crypt,
insertStatement: insertStmt,
selectDescStatement: selectDescStmt,
selectAscStatement: selectAscStmt,
deleteUserStmt: deleteUserStmt,
}, nil
}
// checkColumnExists checks if a column exists in a specified table
func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
rows, err := db.Query(query)
if err != nil {
return false, fmt.Errorf("failed to query table info: %w", err)
} }
defer rows.Close() defer rows.Close()
found := false
for rows.Next() { for rows.Next() {
var ( var cid int
cid int var name, ctype string
name string var notnull, pk int
dataType string var dfltValue sql.NullString
notNull int
dfltVal sql.NullString err = rows.Scan(&cid, &name, &ctype, &notnull, &dfltValue, &pk)
pk int
)
err := rows.Scan(&cid, &name, &dataType, &notNull, &dfltVal, &pk)
if err != nil { if err != nil {
return err return false, fmt.Errorf("failed to scan row: %w", err)
} }
if name == "name" {
found = true if name == columnName {
break return true, nil
} }
} }
err = rows.Err() if err = rows.Err(); err != nil {
if err != nil { return false, err
return err
} }
if found { return false, nil
return nil
}
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
} }