fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-06 20:34:50 +03:00
parent 716009b791
commit 63c510e80d
4 changed files with 80 additions and 66 deletions

View File

@ -96,7 +96,6 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@ -1860,8 +1859,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} }
}
for _, g := range addNewGroups { for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil { if group := account.GetGroup(g); group != nil {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -185,7 +186,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
} }
if peer.Name != update.Name { peerLabelUpdated := peer.Name != update.Name
if peerLabelUpdated {
peer.Name = update.Name peer.Name = update.Name
existingLabels := account.getPeerDNSLabels() existingLabels := account.getPeerDNSLabels()
@ -226,7 +229,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, err return nil, err
} }
expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
if peerLabelUpdated || (expired && peer.LoginExpirationEnabled) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
}
return peer, nil return peer, nil
} }
@ -290,6 +296,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
updateAccountPeers := isPeerInActiveGroup(account, peerID)
err = am.deletePeers(ctx, account, []string{peerID}, userID) err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil { if err != nil {
return err return err
@ -300,7 +308,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
@ -390,9 +400,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
var newPeer *nbpeer.Peer var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string var setupKeyID string
var setupKeyName string var setupKeyName string
var ephemeral bool var ephemeral bool
@ -543,6 +553,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("error getting account: %w", err) return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
} }
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(account)
@ -859,51 +873,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false return false
} }
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
}
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
if peer.SSHKey == sshKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
// trigger network map update
am.updateAccountPeers(ctx, account)
return nil
}
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@ -1010,3 +979,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
} }
return labelMap return labelMap
} }
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
}

View File

@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
} }
} }
// ruleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) ruleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}
// FirewallRule is a rule of the firewall. // FirewallRule is a rule of the firewall.
type FirewallRule struct { type FirewallRule struct {
// PeerIP of the peer // PeerIP of the peer
@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
if err = am.savePolicy(account, policy, isUpdate); err != nil { updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
if err != nil {
return err return err
} }
@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
} }
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
// savePolicy saves or updates a policy in the given account. // savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. // If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules { for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources) rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations) rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
@ -442,18 +457,22 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if isUpdate { if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 { if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
} }
oldPolicy := account.Policies[policyIdx]
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
// Update the existing policy // Update the existing policy
account.Policies[policyIdx] = policyToSave account.Policies[policyIdx] = policyToSave
return nil
return updateAccountPeers, nil
} }
// Add the new policy to the account // Add the new policy to the account
account.Policies = append(account.Policies, policyToSave) account.Policies = append(account.Policies, policyToSave)
return nil return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
} }
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {

View File

@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err return nil, err
} }
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial() account.Network.IncSerial()
@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
}
return nil return nil
} }
@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
} }
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
}