fix refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-11 12:38:34 +03:00
parent 174e07fefd
commit d54b6967ce
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 52 additions and 20 deletions

View File

@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
} }
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }

View File

@ -576,8 +576,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil return false, nil
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers. func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() { if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true return true
@ -585,3 +584,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
} }
return false return false
} }
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return false, err
}
if group.HasPeers() {
return true, nil
}
}
return false, nil
}

View File

@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return nil, err return nil, err
} }
if anyGroupHasPeers(account, newNSGroup.Groups) { if am.anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err return err
} }
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err return err
} }
if anyGroupHasPeers(account, nsGroup.Groups) { if am.anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
@ -279,9 +279,9 @@ func validateDomain(domain string) error {
} }
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled { if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false return false
} }
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
} }

View File

@ -613,7 +613,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks := am.getPeerPostureChecks(account, newPeer) postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
@ -695,7 +699,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks = am.getPeerPostureChecks(account, peer)
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@ -868,7 +876,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks = am.getPeerPostureChecks(account, peer)
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@ -1021,7 +1033,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done() defer wg.Done()
defer func() { <-semaphore }() defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p) postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})

View File

@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) { if am.anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if !policyToSave.Enabled && !oldPolicy.Enabled { if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil return false, nil
} }
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil return updateAccountPeers, nil
} }
@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
// Add the new policy to the account // Add the new policy to the account
account.Policies = append(account.Policies, policyToSave) account.Policies = append(account.Policies, policyToSave)
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
} }
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {

View File

@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err return nil, err
} }
if isRouteChangeAffectPeers(account, &newRoute) { if am.isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) { if am.isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
// isRouteChangeAffectPeers checks if a given route affects peers by determining // isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
} }