mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 14:03:35 +01:00
update transaction logic
This commit is contained in:
parent
adf521a9d9
commit
e3f3d2c1bd
@ -20,6 +20,11 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@ -36,10 +41,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -846,17 +847,7 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
||||
// newly groups to create and an error if any occurred.
|
||||
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
|
||||
existedGroupsByName := make(map[string]*nbgroup.Group)
|
||||
for _, group := range groups {
|
||||
existedGroupsByName[group.Name] = group
|
||||
@ -880,7 +871,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID
|
||||
if !exists {
|
||||
group = &nbgroup.Group{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
AccountID: user.AccountID,
|
||||
Name: name,
|
||||
Issued: nbgroup.GroupIssuedJWT,
|
||||
}
|
||||
@ -1836,16 +1827,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// skip update if no changes
|
||||
if !hasChanges {
|
||||
return nil
|
||||
}
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlockPeer != nil {
|
||||
@ -1853,19 +1834,39 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
}()
|
||||
|
||||
if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
|
||||
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
|
||||
|
||||
var addNewGroups []string
|
||||
var removeOldGroups []string
|
||||
var hasChanges bool
|
||||
var user *User
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, user, groups, jwtGroupsNames)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting JWT groups changes: %w", err)
|
||||
}
|
||||
|
||||
hasChanges = changed
|
||||
// skip update if no changes
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
|
||||
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
|
||||
|
||||
user.AutoGroups = updatedAutoGroups
|
||||
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
||||
return fmt.Errorf("error saving user: %w", err)
|
||||
@ -1873,7 +1874,22 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups)
|
||||
groups, err = transaction.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
|
||||
updatedGroups, err := am.updateUserPeersInGroups(ctx, groupsMap, peers, addNewGroups, removeOldGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||
}
|
||||
@ -1895,6 +1911,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return err
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
if err != nil {
|
||||
|
@ -1185,3 +1185,37 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||
}
|
||||
|
||||
func TestSqlite_CreateAndGetObjcetInTransaction(t *testing.T) {
|
||||
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
group := &nbgroup.Group{
|
||||
ID: "group-id",
|
||||
AccountID: "account-id",
|
||||
Name: "group-name",
|
||||
Issued: "api",
|
||||
Peers: nil,
|
||||
}
|
||||
store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
|
||||
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
t.Fatal("failed to save group")
|
||||
return err
|
||||
}
|
||||
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get group")
|
||||
return err
|
||||
}
|
||||
|
||||
t.Logf("group: %v", group)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@ -15,7 +17,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -1255,36 +1256,31 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd,
|
||||
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
|
||||
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
|
||||
|
||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userPeerIDMap := make(map[string]struct{}, len(peers))
|
||||
for _, peer := range peers {
|
||||
userPeerIDMap[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, gid := range groupsToAdd {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
group, ok := accountGroups[gid]
|
||||
if !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
addUserPeersToGroup(userPeerIDMap, group)
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
|
||||
for _, gid := range groupsToRemove {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
group, ok := accountGroups[gid]
|
||||
if !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
removeUserPeersFromGroup(userPeerIDMap, group)
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
|
Loading…
Reference in New Issue
Block a user