Remove condition check for network serial update

This commit is contained in:
bcmmbaga 2024-07-20 20:36:36 +03:00
parent f5ec234f09
commit bb08adcbac
No known key found for this signature in database
GPG Key ID: 7249A19D20613553
7 changed files with 51 additions and 65 deletions

View File

@ -1724,17 +1724,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
updateAccountPeers := areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups)
if updateAccountPeers {
account.Network.IncSerial()
}
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
} else {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
if updateAccountPeers {
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
am.updateAccountPeers(ctx, account)
}
unlock()

View File

@ -92,11 +92,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers := (areGroupChangesAffectPeers(account, addedGroups) && anyGroupHasPeers(account, addedGroups)) ||
areGroupChangesAffectPeers(account, removedGroups) && anyGroupHasPeers(account, removedGroups)
if updateAccountPeers {
account.Network.IncSerial()
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
@ -113,6 +109,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
updateAccountPeers := (areGroupChangesAffectPeers(account, addedGroups) && anyGroupHasPeers(account, addedGroups)) ||
areGroupChangesAffectPeers(account, removedGroups) && anyGroupHasPeers(account, removedGroups)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
}

View File

@ -170,16 +170,12 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
updateAccountPeers := areGroupChangesAffectPeers(account, newGroupIDs)
if updateAccountPeers {
account.Network.IncSerial()
}
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
return err
}
if updateAccountPeers {
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
@ -329,16 +325,12 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
group.Peers = append(group.Peers, peerID)
}
updateAccountPeers := areGroupChangesAffectPeers(account, []string{group.ID})
if updateAccountPeers {
account.Network.IncSerial()
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if updateAccountPeers {
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
}
@ -360,11 +352,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
updateAccountPeers := areGroupChangesAffectPeers(account, []string{group.ID})
if updateAccountPeers {
account.Network.IncSerial()
}
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
@ -374,7 +362,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
}
if updateAccountPeers {
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
}

View File

@ -79,16 +79,12 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
account.NameServerGroups[newNSGroup.ID] = newNSGroup
updateAccountPeers := anyGroupHasPeers(account, newNSGroup.Groups)
if updateAccountPeers {
account.Network.IncSerial()
}
if err := am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if updateAccountPeers {
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@ -116,17 +112,14 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
}
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
updateAccountPeers := anyGroupHasPeers(account, nsGroupToSave.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
if updateAccountPeers {
account.Network.IncSerial()
}
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if updateAccountPeers {
if anyGroupHasPeers(account, nsGroupToSave.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@ -151,16 +144,12 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
delete(account.NameServerGroups, nsGroupID)
updateAccountPeers := anyGroupHasPeers(account, nsGroup.Groups)
if updateAccountPeers {
account.Network.IncSerial()
}
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if updateAccountPeers {
if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())

View File

@ -212,16 +212,14 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
account.UpdatePeer(peer)
expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
if expired && peer.LoginExpirationEnabled {
account.Network.IncSerial()
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
if expired && peer.LoginExpirationEnabled {
am.updateAccountPeers(ctx, account)
}
@ -501,11 +499,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
account.Peers[newPeer.ID] = newPeer
updateAccountPeers := areGroupChangesAffectPeers(account, groupsToAdd)
if updateAccountPeers {
account.Network.IncSerial()
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, nil, nil, err
@ -523,7 +517,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if updateAccountPeers {
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
}

View File

@ -191,6 +191,16 @@ func (p *Policy) UpgradeAndFix() {
}
}
// AreAllRuleGroupsEmpty checks if all rule groups in the policy are effectively empty.
func (p *Policy) AreAllRuleGroupsEmpty() bool {
for _, rule := range p.Rules {
if len(rule.Sources) != 0 && len(rule.Destinations) != 0 {
return false
}
}
return true
}
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
@ -364,7 +374,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if !policy.AreAllRuleGroupsEmpty() {
am.updateAccountPeers(ctx, account)
}
return nil
}
@ -391,7 +403,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if !policy.AreAllRuleGroupsEmpty() {
am.updateAccountPeers(ctx, account)
}
return nil
}
@ -561,3 +575,13 @@ func getPostureChecks(account *Account, postureChecksID string) *posture.Checks
}
return nil
}
// isPolicyRuleGroupsEmpty checks if a given policy has rules with empty sources and destinations.
func isPolicyRuleGroupsEmpty(policy *Policy) bool {
for _, rule := range policy.Rules {
if len(rule.Sources) != 0 && len(rule.Destinations) != 0 {
return false
}
}
return true
}

View File

@ -70,23 +70,20 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
}
updateAccountPeers := false
action := activity.PostureCheckCreated
if exists {
action = activity.PostureCheckUpdated
}
updateAccountPeers, _ = isPostureCheckLinkedToPolicy(account, postureChecks.ID)
if updateAccountPeers {
account.Network.IncSerial()
}
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
updateAccountPeers, _ := isPostureCheckLinkedToPolicy(account, postureChecks.ID)
if exists && updateAccountPeers {
am.updateAccountPeers(ctx, account)
}