fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-06 20:34:50 +03:00
parent 716009b791
commit 63c510e80d
4 changed files with 80 additions and 66 deletions

View File

@ -96,7 +96,6 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@ -478,12 +477,12 @@ func (a *Account) GetPeerNetworkMap(
}
nm := &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: routesFirewallRules,
}
@ -1860,7 +1859,9 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
am.updateAccountPeers(ctx, account)
}
}
for _, g := range addNewGroups {

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"slices"
"strings"
"sync"
"time"
@ -185,7 +186,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
if peer.Name != update.Name {
peerLabelUpdated := peer.Name != update.Name
if peerLabelUpdated {
peer.Name = update.Name
existingLabels := account.getPeerDNSLabels()
@ -226,7 +229,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err
}
am.updateAccountPeers(ctx, account)
expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
if peerLabelUpdated || (expired && peer.LoginExpirationEnabled) {
am.updateAccountPeers(ctx, account)
}
return peer, nil
}
@ -290,6 +296,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers := isPeerInActiveGroup(account, peerID)
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
return err
@ -300,7 +308,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
}
return nil
}
@ -390,9 +400,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string
var setupKeyName string
var ephemeral bool
@ -543,6 +553,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
}
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@ -859,51 +873,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
}
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
if peer.SSHKey == sshKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
// trigger network map update
am.updateAccountPeers(ctx, account)
return nil
}
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@ -1010,3 +979,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
}
return labelMap
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
}

View File

@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
}
}
// ruleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) ruleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
if err = am.savePolicy(account, policy, isUpdate); err != nil {
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
if err != nil {
return err
}
@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
}
return nil
}
@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
@ -442,18 +457,22 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
oldPolicy := account.Policies[policyIdx]
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return nil
return updateAccountPeers, nil
}
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return nil
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {

View File

@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
am.updateAccountPeers(ctx, account)
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
am.updateAccountPeers(ctx, account)
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
am.updateAccountPeers(ctx, account)
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account)
}
return nil
}
@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
}
return &portInfo
}
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
}