Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-09 01:17:01 +03:00
parent 6dc185e141
commit bdeb95c58c
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
3 changed files with 187 additions and 247 deletions

View File

@ -2126,12 +2126,12 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.GroupsPropagationEnabled {
removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups)
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil {
return err
}
newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups)
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}

View File

@ -79,7 +79,7 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
@ -89,66 +89,35 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
return status.NewUserNotPartOfAccountError()
}
var (
eventsToStore []func()
groupsToSave []*nbgroup.Group
)
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs)
if err != nil {
return err
}
var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
return fmt.Errorf("failed to save groups: %w", err)
}
return nil
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
})
if err != nil {
return err
@ -166,13 +135,13 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
@ -184,36 +153,34 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}
for _, peerID := range addedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err)
continue
}
peerCopy := peer // copy to avoid closure issues
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
})
}
for _, peerID := range removedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err)
continue
}
peerCopy := peer // copy to avoid closure issues
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
})
}
@ -246,28 +213,27 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
return status.NewUserNotPartOfAccountError()
}
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = am.validateDeleteGroup(ctx, group, userID); err != nil {
return err
}
var group *nbgroup.Group
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID)
})
if err != nil {
return err
@ -279,6 +245,11 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
}
// DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
@ -289,36 +260,31 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
return status.NewUserNotPartOfAccountError()
}
var (
allErrors error
groupIDsToDelete []string
deletedGroups []*nbgroup.Group
)
for _, groupID := range groupIDs {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
continue
}
if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
})
if err != nil {
return err
@ -333,36 +299,30 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err != nil {
return err
}
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
@ -377,38 +337,30 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
updated := false
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
updated = true
break
}
}
if !updated {
return nil
}
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err != nil {
return err
}
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
if err != nil {
return err
@ -421,10 +373,43 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return nil
}
func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
// validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return status.Errorf(status.NotFound, "user not found")
}
@ -433,27 +418,27 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group
}
}
if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked {
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked {
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked {
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked {
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked {
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
@ -462,7 +447,7 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
@ -477,8 +462,8 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) {
routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
@ -494,8 +479,8 @@ func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accou
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
@ -512,8 +497,8 @@ func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, acco
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
@ -531,8 +516,8 @@ func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, account
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
@ -547,8 +532,8 @@ func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, ac
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) {
users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
@ -563,12 +548,12 @@ func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accoun
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
@ -577,13 +562,13 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context,
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked {
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked {
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked {
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
@ -591,40 +576,6 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context,
return false, nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
@ -634,22 +585,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
}
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false
}

View File

@ -331,7 +331,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers := isPeerInActiveGroup(account, peerID)
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
if err != nil {
return err
}
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
@ -594,9 +597,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
if areGroupChangesAffectPeers(account, groupsToAdd) {
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
if err != nil {
return nil, nil, nil, err
}
if newGroupsAffectsPeers {
am.updateAccountPeers(ctx, accountID)
}
@ -1033,12 +1041,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
}