update transaction logic

This commit is contained in:
Pascal Fischer 2024-10-04 15:17:28 +02:00
parent adf521a9d9
commit e3f3d2c1bd
3 changed files with 102 additions and 52 deletions

View File

@ -20,6 +20,11 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
@ -36,10 +41,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
const (
@ -846,17 +847,7 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
// newly groups to create and an error if any occurred.
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return false, nil, nil, err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return false, nil, nil, err
}
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range groups {
existedGroupsByName[group.Name] = group
@ -880,7 +871,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID
if !exists {
group = &nbgroup.Group{
ID: xid.New().String(),
AccountID: accountID,
AccountID: user.AccountID,
Name: name,
Issued: nbgroup.GroupIssuedJWT,
}
@ -1836,16 +1827,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
if err != nil {
return err
}
// skip update if no changes
if !hasChanges {
return nil
}
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlockPeer != nil {
@ -1853,19 +1834,39 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
}()
if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
var addNewGroups []string
var removeOldGroups []string
var hasChanges bool
var user *User
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, user, groups, jwtGroupsNames)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
hasChanges = changed
// skip update if no changes
if !changed {
return nil
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
@ -1873,7 +1874,22 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups)
groups, err = transaction.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(ctx, groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
@ -1895,6 +1911,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return err
}
if !hasChanges {
return nil
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {

View File

@ -1185,3 +1185,37 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}
func TestSqlite_CreateAndGetObjcetInTransaction(t *testing.T) {
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
}

View File

@ -8,6 +8,8 @@ import (
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
@ -15,7 +17,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
@ -1255,36 +1256,31 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd,
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil {
return nil, err
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return nil, err
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
for _, gid := range groupsToRemove {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return nil, err
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)