Copy account when storing to avoid reference issues

This commit is contained in:
braginini
2024-04-17 17:03:21 +02:00
parent e7a6483912
commit a75f982fcd

View File

@ -150,7 +150,7 @@ func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
} }
// Use reflect.Slice to get a slice of the records for the current batch // Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface() batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Error; err != nil { if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err return err
} }
} }
@ -160,51 +160,65 @@ func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
func (s *SqliteStore) SaveAccount(account *Account) error { func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now() start := time.Now()
for _, key := range account.SetupKeys { // operate over a fresh copy as we will modify its fields
account.SetupKeysG = append(account.SetupKeysG, *key) accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
} }
for id, peer := range account.Peers { accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
peer.ID = id peer.ID = id
account.PeersG = append(account.PeersG, *peer) peer.AccountID = account.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
} }
for id, user := range account.Users { accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id user.Id = id
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs { for id, pat := range user.PATs {
pat.ID = id pat.ID = id
user.PATsG = append(user.PATsG, *pat) user.PATsG = append(user.PATsG, *pat)
} }
account.UsersG = append(account.UsersG, *user) accCopy.UsersG = append(accCopy.UsersG, *user)
} }
for id, group := range account.Groups { accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id group.ID = id
account.GroupsG = append(account.GroupsG, *group) group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
} }
for id, route := range account.Routes { accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id route.ID = id
account.RoutesG = append(account.RoutesG, *route) route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
} }
for id, ns := range account.NameServerGroups { accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
} }
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
result = tx.Select(clause.Associations).Delete(account) result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
@ -213,36 +227,36 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
Session(&gorm.Session{FullSaveAssociations: true}). Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}). Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG"). Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG").
Create(account) Create(accCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
const batchSize = 500 const batchSize = 500
err := batchInsert(account.PeersG, batchSize, tx) err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil { if err != nil {
return err return err
} }
err = batchInsert(account.UsersG, batchSize, tx) err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil { if err != nil {
return err return err
} }
err = batchInsert(account.GroupsG, batchSize, tx) err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil { if err != nil {
return err return err
} }
err = batchInsert(account.RoutesG, batchSize, tx) err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil { if err != nil {
return err return err
} }
return batchInsert(account.SetupKeysG, batchSize, tx) return batchInsert(accCopy.SetupKeysG, batchSize, tx)
}) })
took := time.Since(start) took := time.Since(start)
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took) s.metrics.StoreMetrics().CountPersistenceDuration(took)
} }
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), account.Id) log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err return err
} }