diff --git a/management/server/account.go b/management/server/account.go index d5e8c8cf8..1463ae033 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -96,7 +96,6 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) @@ -478,12 +477,12 @@ func (a *Account) GetPeerNetworkMap( } nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } @@ -1860,7 +1859,9 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + am.updateAccountPeers(ctx, account) + } } for _, g := range addNewGroups { diff --git a/management/server/peer.go b/management/server/peer.go index 97e11c08a..e358483af 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "slices" "strings" "sync" "time" @@ -185,7 +186,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - if peer.Name != update.Name { + peerLabelUpdated := peer.Name != update.Name + + if peerLabelUpdated { peer.Name = update.Name existingLabels := account.getPeerDNSLabels() @@ -226,7 +229,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - am.updateAccountPeers(ctx, account) + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if peerLabelUpdated || (expired && peer.LoginExpirationEnabled) { + am.updateAccountPeers(ctx, account) + } return peer, nil } @@ -290,6 +296,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + updateAccountPeers := isPeerInActiveGroup(account, peerID) + err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { return err @@ -300,7 +308,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -390,9 +400,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer + var groupsToAdd []string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - var groupsToAdd []string var setupKeyID string var setupKeyName string var ephemeral bool @@ -543,6 +553,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, fmt.Errorf("error getting account: %w", err) } + if areGroupChangesAffectPeers(account, groupsToAdd) { + am.updateAccountPeers(ctx, account) + } + am.updateAccountPeers(ctx, account) approvedPeersMap, err := am.GetValidatedPeers(account) @@ -859,51 +873,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings return false } -// UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if sshKey == "" { - log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) - return nil - } - - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, account.Id) - if err != nil { - return err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - if peer.SSHKey == sshKey { - log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) - return nil - } - - peer.SSHKey = sshKey - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err - } - - // trigger network map update - am.updateAccountPeers(ctx, account) - - return nil -} - // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1010,3 +979,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { } return labelMap } + +// IsPeerInActiveGroup checks if the given peer is part of a group that is used +// in an active DNS, route, or ACL configuration. +func isPeerInActiveGroup(account *Account, peerID string) bool { + peerGroupIDs := make([]string, 0) + for _, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + peerGroupIDs = append(peerGroupIDs, group.ID) + } + } + return areGroupChangesAffectPeers(account, peerGroupIDs) +} diff --git a/management/server/policy.go b/management/server/policy.go index 75647de44..95bae8973 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() { } } +// ruleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) ruleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + + return groups +} + // FirewallRule is a rule of the firewall. type FirewallRule struct { // PeerIP of the peer @@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = am.savePolicy(account, policy, isUpdate); err != nil { + updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if err != nil { return err } @@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) // savePolicy saves or updates a policy in the given account. // If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { for index, rule := range policyToSave.Rules { rule.Sources = filterValidGroupIDs(account, rule.Sources) rule.Destinations = filterValidGroupIDs(account, rule.Destinations) @@ -442,18 +457,22 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if isUpdate { policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) if policyIdx < 0 { - return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } + oldPolicy := account.Policies[policyIdx] + updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + // Update the existing policy account.Policies[policyIdx] = policyToSave - return nil + + return updateAccountPeers, nil } // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return nil + return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/route.go b/management/server/route.go index 39ee6170c..1cf00b37c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, &newRoute) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + oldRoute := account.Routes[routeToSave.ID] account.Routes[routeToSave.ID] = routeToSave account.Network.IncSerial() @@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, routy) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { } return &portInfo } + +// isRouteChangeAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers +func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +}