mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 01:58:16 +02:00
@ -96,7 +96,6 @@ type AccountManager interface {
|
|||||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||||
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
|
|
||||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
||||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||||
@ -1860,8 +1859,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
// Propagate changes to peers if group propagation is enabled
|
// Propagate changes to peers if group propagation is enabled
|
||||||
if settings.GroupsPropagationEnabled {
|
if settings.GroupsPropagationEnabled {
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
|
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
if group := account.GetGroup(g); group != nil {
|
if group := account.GetGroup(g); group != nil {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -185,7 +186,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.Name != update.Name {
|
peerLabelUpdated := peer.Name != update.Name
|
||||||
|
|
||||||
|
if peerLabelUpdated {
|
||||||
peer.Name = update.Name
|
peer.Name = update.Name
|
||||||
|
|
||||||
existingLabels := account.getPeerDNSLabels()
|
existingLabels := account.getPeerDNSLabels()
|
||||||
@ -226,7 +229,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||||
|
if peerLabelUpdated || (expired && peer.LoginExpirationEnabled) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
@ -290,6 +296,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||||
|
|
||||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -300,7 +308,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -390,9 +400,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var newPeer *nbpeer.Peer
|
var newPeer *nbpeer.Peer
|
||||||
|
var groupsToAdd []string
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
var groupsToAdd []string
|
|
||||||
var setupKeyID string
|
var setupKeyID string
|
||||||
var setupKeyName string
|
var setupKeyName string
|
||||||
var ephemeral bool
|
var ephemeral bool
|
||||||
@ -543,6 +553,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
@ -859,51 +873,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerSSHKey updates peer's public SSH key
|
|
||||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
|
||||||
if sshKey == "" {
|
|
||||||
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
|
||||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := account.GetPeer(peerID)
|
|
||||||
if peer == nil {
|
|
||||||
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer.SSHKey == sshKey {
|
|
||||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.SSHKey = sshKey
|
|
||||||
account.UpdatePeer(peer)
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// trigger network map update
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
@ -1010,3 +979,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
|||||||
}
|
}
|
||||||
return labelMap
|
return labelMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
|
// in an active DNS, route, or ACL configuration.
|
||||||
|
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||||
|
peerGroupIDs := make([]string, 0)
|
||||||
|
for _, group := range account.Groups {
|
||||||
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||||
|
}
|
||||||
|
@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ruleGroups returns a list of all groups referenced in the policy's rules,
|
||||||
|
// including sources and destinations.
|
||||||
|
func (p *Policy) ruleGroups() []string {
|
||||||
|
groups := make([]string, 0)
|
||||||
|
for _, rule := range p.Rules {
|
||||||
|
groups = append(groups, rule.Sources...)
|
||||||
|
groups = append(groups, rule.Destinations...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
// FirewallRule is a rule of the firewall.
|
// FirewallRule is a rule of the firewall.
|
||||||
type FirewallRule struct {
|
type FirewallRule struct {
|
||||||
// PeerIP of the peer
|
// PeerIP of the peer
|
||||||
@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
|||||||
|
|
||||||
// savePolicy saves or updates a policy in the given account.
|
// savePolicy saves or updates a policy in the given account.
|
||||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||||
for index, rule := range policyToSave.Rules {
|
for index, rule := range policyToSave.Rules {
|
||||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||||
@ -442,18 +457,22 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
|
|||||||
if isUpdate {
|
if isUpdate {
|
||||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||||
if policyIdx < 0 {
|
if policyIdx < 0 {
|
||||||
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldPolicy := account.Policies[policyIdx]
|
||||||
|
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||||
|
|
||||||
// Update the existing policy
|
// Update the existing policy
|
||||||
account.Policies[policyIdx] = policyToSave
|
account.Policies[policyIdx] = policyToSave
|
||||||
return nil
|
|
||||||
|
return updateAccountPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the new policy to the account
|
// Add the new policy to the account
|
||||||
account.Policies = append(account.Policies, policyToSave)
|
account.Policies = append(account.Policies, policyToSave)
|
||||||
|
|
||||||
return nil
|
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||||
|
@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||||
|
|
||||||
@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldRoute := account.Routes[routeToSave.ID]
|
||||||
account.Routes[routeToSave.ID] = routeToSave
|
account.Routes[routeToSave.ID] = routeToSave
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||||
|
|
||||||
@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
|
if isRouteChangeAffectPeers(account, routy) {
|
||||||
am.updateAccountPeers(ctx, account)
|
am.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
|||||||
}
|
}
|
||||||
return &portInfo
|
return &portInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||||
|
// if it has a routing peer, distribution, or peer groups that include peers
|
||||||
|
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||||
|
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user