netbird/management/server/group.go
bcmmbaga 78e238646c
refactor groups methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 16:32:31 +03:00

556 lines
17 KiB
Go

package server
import (
"context"
"errors"
"fmt"
"slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
)
// GroupLinkError reports that a group is still referenced by another
// resource. In this file it is returned by validateDeleteGroup to reject a
// group deletion.
type GroupLinkError struct {
	// Resource is the kind of resource referencing the group (for example "route").
	Resource string
	// Name identifies the referencing resource.
	Name string
}

// Error implements the error interface.
func (e *GroupLinkError) Error() string {
	const format = "group has been linked to %s: %s"
	return fmt.Sprintf(format, e.Resource, e.Name)
}
// CheckGroupPermissions reports whether the user may view groups of the account.
//
// It loads the account settings and the user from the store and returns a
// PermissionDenied status error when the user belongs to a different account,
// or when IsAdminOrServiceUser is false for the user and the settings have
// RegularUsersViewBlocked set. Store lookup errors are returned unchanged.
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
	settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
	if err != nil {
		return err
	}

	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return err
	}

	viewBlocked := !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked
	foreignAccount := user.AccountID != accountID
	if viewBlocked || foreignAccount {
		return status.Errorf(status.PermissionDenied, "groups are blocked for users")
	}

	return nil
}
// GetGroup fetches the group with the given ID from the account, after
// CheckGroupPermissions has confirmed that the user may view groups.
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
	err := am.CheckGroupPermissions(ctx, accountID, userID)
	if err != nil {
		return nil, err
	}

	return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
}
// GetAllGroups fetches every group of the account, after CheckGroupPermissions
// has confirmed that the user may view groups.
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
	err := am.CheckGroupPermissions(ctx, accountID, userID)
	if err != nil {
		return nil, err
	}

	return am.Store.GetAccountGroups(ctx, accountID)
}
// GetGroupByName looks up a group of the account by its name, delegating the
// lookup to the store. No permission check is performed here.
//
// NOTE(review): the previous comment stated that, among groups sharing the
// name, the one with the most peers is returned — that depends on the Store
// implementation; confirm there.
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
	group, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
	return group, err
}
// SaveGroup saves a single group by delegating to SaveGroups with a
// one-element slice.
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
	groups := []*nbgroup.Group{newGroup}
	return am.SaveGroups(ctx, accountID, userID, groups)
}
// SaveGroups validates the given groups and saves them to the account.
//
// The user must belong to the account. For every group:
//   - a group without an ID is accepted only when it is API-issued; it must
//     not share its name with an existing group and receives a generated ID;
//   - every referenced peer must exist in the account.
//
// All groups are then saved in one transaction together with a network serial
// increment. Once the transaction succeeds, the prepared activity events are
// stored and the account is handed to updateAccountPeers.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return err
	}

	if user.AccountID != accountID {
		return status.Errorf(status.PermissionDenied, "no permission to create group")
	}

	var pendingEvents []func()
	var validated []*nbgroup.Group

	for _, group := range newGroups {
		if group.ID == "" {
			// Only API-issued groups may arrive without an ID.
			if group.Issued != nbgroup.GroupIssuedAPI {
				return status.Errorf(status.InvalidArgument, "%s group without ID set", group.Issued)
			}

			existing, lookupErr := am.Store.GetGroupByName(ctx, LockingStrengthShare, group.Name, accountID)
			if lookupErr != nil {
				// A NotFound status means the name is free; anything else is a real failure.
				if s, ok := status.FromError(lookupErr); !ok || s.ErrorType != status.NotFound {
					return lookupErr
				}
			}

			// Avoid duplicate groups only for the API issued groups.
			// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
			if existing != nil {
				return status.Errorf(status.AlreadyExists, "group with name %s already exists", group.Name)
			}

			group.ID = xid.New().String()
		}

		for _, peerID := range group.Peers {
			_, peerErr := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
			if peerErr != nil {
				return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
			}
		}

		group.AccountID = accountID
		validated = append(validated, group)
		pendingEvents = append(pendingEvents, am.prepareGroupEvents(ctx, userID, accountID, group)...)
	}

	persist := func(transaction Store) error {
		if serialErr := transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); serialErr != nil {
			return fmt.Errorf("failed to increment network serial: %w", serialErr)
		}

		if saveErr := transaction.SaveGroups(ctx, LockingStrengthUpdate, validated); saveErr != nil {
			return fmt.Errorf("failed to save groups: %w", saveErr)
		}

		return nil
	}
	if txErr := am.Store.ExecuteInTransaction(ctx, persist); txErr != nil {
		return txErr
	}

	// Events are emitted only after the groups have been persisted.
	for i := range pendingEvents {
		pendingEvents[i]()
	}

	account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
	if err != nil {
		return fmt.Errorf("error getting account: %w", err)
	}

	am.updateAccountPeers(ctx, account)

	return nil
}
// prepareGroupEvents builds the deferred activity events for saving newGroup.
//
// When the group cannot be loaded from the store it is treated as new: a
// GroupCreated event is queued and all of its peers count as added. Otherwise
// the peer lists of the stored and the new group are compared. One
// GroupAddedToPeer / GroupRemovedFromPeer event is queued per affected peer;
// peers that cannot be loaded are logged and skipped. Nothing is stored until
// the returned functions are invoked.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
	var events []func()
	var added, removed []string

	existing, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID)
	if err != nil || existing == nil {
		added = append(added, newGroup.Peers...)
		events = append(events, func() {
			am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
		})
	} else {
		added = difference(newGroup.Peers, existing.Peers)
		removed = difference(existing.Peers, newGroup.Peers)
	}

	for _, peerID := range added {
		// member is declared per iteration, so each closure captures its own peer.
		member, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
		if err != nil {
			log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
			continue
		}

		events = append(events, func() {
			meta := map[string]any{
				"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": member.IP.String(),
				"peer_fqdn": member.FQDN(am.GetDNSDomain()),
			}
			am.StoreEvent(ctx, userID, member.ID, accountID, activity.GroupAddedToPeer, meta)
		})
	}

	for _, peerID := range removed {
		member, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
		if err != nil {
			log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
			continue
		}

		events = append(events, func() {
			meta := map[string]any{
				"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": member.IP.String(),
				"peer_fqdn": member.FQDN(am.GetDNSDomain()),
			}
			am.StoreEvent(ctx, userID, member.ID, accountID, activity.GroupRemovedFromPeer, meta)
		})
	}

	return events
}
// difference returns, in their original order, the elements of a that do not
// occur in b. Duplicates in a are kept. The result is nil when no element
// qualifies.
func difference(a, b []string) []string {
	excluded := make(map[string]struct{}, len(b))
	for i := range b {
		excluded[b[i]] = struct{}{}
	}

	var remaining []string
	for i := range a {
		if _, skip := excluded[a[i]]; skip {
			continue
		}
		remaining = append(remaining, a[i])
	}

	return remaining
}
// DeleteGroup deletes a single group from the account.
//
// The user must belong to the account, the group named "All" can never be
// deleted, and validateDeleteGroup must accept the deletion. The group is
// deleted in one transaction together with a network serial increment;
// afterwards a GroupDeleted event is stored and the account is handed to
// updateAccountPeers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return err
	}

	if user.AccountID != accountID {
		return status.Errorf(status.PermissionDenied, "no permission to delete group")
	}

	group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
	if err != nil {
		return err
	}

	if group.Name == "All" {
		return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
	}

	if validationErr := am.validateDeleteGroup(ctx, group, userID); validationErr != nil {
		return validationErr
	}

	deleteInTx := func(transaction Store) error {
		if serialErr := transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); serialErr != nil {
			return fmt.Errorf("failed to increment network serial: %w", serialErr)
		}

		if deleteErr := transaction.DeleteGroup(ctx, LockingStrengthUpdate, groupID, accountID); deleteErr != nil {
			return fmt.Errorf("failed to delete group: %w", deleteErr)
		}

		return nil
	}
	if txErr := am.Store.ExecuteInTransaction(ctx, deleteInTx); txErr != nil {
		return txErr
	}

	am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())

	account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
	if err != nil {
		return fmt.Errorf("error getting account: %w", err)
	}

	am.updateAccountPeers(ctx, account)

	return nil
}
// DeleteGroups deletes every deletable group among groupIDs from the account.
//
// The user must belong to the account. Each group is looked up and checked
// with validateDeleteGroup:
//   - groups that no longer exist (NotFound) are skipped silently;
//   - any other lookup failure, and every validation failure, is collected
//     into the returned error while the remaining groups are still processed.
//
// The groups that passed are deleted in one transaction together with a
// network serial increment; afterwards a GroupDeleted event is stored per
// group and the account is handed to updateAccountPeers. When no group is
// deletable, nothing is written and only the collected errors are returned.
//
// A transaction failure is returned on its own; otherwise the result is the
// joined per-group errors, or nil when every group was handled cleanly.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
	user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
	if err != nil {
		return err
	}

	if user.AccountID != accountID {
		return status.Errorf(status.PermissionDenied, "no permission to delete groups")
	}

	var (
		allErrors        error
		groupIDsToDelete []string
		deletedGroups    []*nbgroup.Group
	)

	for _, groupID := range groupIDs {
		group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
		if err != nil {
			// A group that is already gone needs no deletion; any other
			// lookup failure must not be hidden from the caller.
			if s, ok := status.FromError(err); !ok || s.ErrorType != status.NotFound {
				allErrors = errors.Join(allErrors, fmt.Errorf("failed to get group %s: %w", groupID, err))
			}
			continue
		}

		if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
			allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
			continue
		}

		groupIDsToDelete = append(groupIDsToDelete, groupID)
		deletedGroups = append(deletedGroups, group)
	}

	// Nothing to delete: do not bump the network serial or refresh peers for
	// a request that changes no state.
	if len(groupIDsToDelete) == 0 {
		return allErrors
	}

	err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
		if err := transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
			return fmt.Errorf("failed to increment network serial: %w", err)
		}

		if err := transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil {
			return fmt.Errorf("failed to delete group: %w", err)
		}

		return nil
	})
	if err != nil {
		return err
	}

	for _, group := range deletedGroups {
		am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
	}

	account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
	if err != nil {
		return fmt.Errorf("error getting account: %w", err)
	}

	am.updateAccountPeers(ctx, account)

	return allErrors
}
// ListGroups returns all groups of the account as a slice. The order of the
// result is unspecified because the groups are read from a map.
//
// NOTE(review): the account write lock is taken although the account is only
// read here — confirm whether a read lock would be sufficient.
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
	unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
	defer unlock()

	account, err := am.Store.GetAccount(ctx, accountID)
	if err != nil {
		return nil, err
	}

	result := make([]*nbgroup.Group, len(account.Groups))
	next := 0
	for _, g := range account.Groups {
		result[next] = g
		next++
	}

	return result, nil
}
// GroupAddPeer adds the peer to the group unless it is already a member.
//
// The account is loaded under its write lock. The network serial is
// incremented, the account saved and updateAccountPeers called whether or not
// the peer was newly added. A NotFound status error is returned when the
// group does not exist in the account.
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
	unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
	defer unlock()

	account, err := am.Store.GetAccount(ctx, accountID)
	if err != nil {
		return err
	}

	group, found := account.Groups[groupID]
	if !found {
		return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
	}

	if !slices.Contains(group.Peers, peerID) {
		group.Peers = append(group.Peers, peerID)
	}

	account.Network.IncSerial()
	if saveErr := am.Store.SaveAccount(ctx, account); saveErr != nil {
		return saveErr
	}

	am.updateAccountPeers(ctx, account)

	return nil
}
// GroupDeletePeer removes the peer from the group.
//
// The account is loaded under its write lock. Every occurrence of peerID is
// removed from the group's peer list. If the peer was a member, the network
// serial is incremented, the account is saved once and updateAccountPeers is
// called. If the peer was not a member, nothing is modified, saved or
// propagated and nil is returned.
//
// A NotFound status error is returned when the group does not exist in the
// account; store errors are returned unchanged.
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
	unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
	defer unlock()

	account, err := am.Store.GetAccount(ctx, accountID)
	if err != nil {
		return err
	}

	group, ok := account.Groups[groupID]
	if !ok {
		return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
	}

	// Filter in one pass instead of deleting while ranging over the slice:
	// the range loop keeps iterating over the original length, so a second
	// match would slice past the shortened list and panic.
	membersBefore := len(group.Peers)
	group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
		return id == peerID
	})
	if len(group.Peers) == membersBefore {
		// The peer was not a member: avoid bumping the serial in memory and
		// notifying peers about a change that was never persisted.
		return nil
	}

	account.Network.IncSerial()
	if err = am.Store.SaveAccount(ctx, account); err != nil {
		return err
	}

	am.updateAccountPeers(ctx, account)

	return nil
}
// validateDeleteGroup checks whether the group may be deleted by the user.
//
// It rejects the deletion when:
//   - the group is the built-in "All" group (InvalidArgument);
//   - the group is integration-issued and the user is not both an admin and a
//     service user (PermissionDenied);
//   - the group is still referenced by a route, name server group, policy,
//     setup key or user, by the account's disabled DNS management groups, or
//     by the integrated validator groups (*GroupLinkError).
//
// The isGroupLinkedTo* helpers report "not linked" when their store lookup
// fails, so such a failure does not block the deletion here.
func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
	// The "All" group must never be deleted. Checking it here protects every
	// caller of the validator, including bulk deletion.
	if group.Name == "All" {
		return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
	}

	// disable a deleting integration group if the initiator is not an admin service user
	if group.Issued == nbgroup.GroupIssuedIntegration {
		executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
		if err != nil {
			return status.Errorf(status.NotFound, "user not found")
		}
		if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
			return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
		}
	}

	if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.ID, group.AccountID); isLinked {
		return &GroupLinkError{"route", string(linkedRoute.NetID)}
	}

	if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.ID, group.AccountID); isLinked {
		return &GroupLinkError{"name server groups", linkedDns.Name}
	}

	if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.ID, group.AccountID); isLinked {
		return &GroupLinkError{"policy", linkedPolicy.Name}
	}

	if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.ID, group.AccountID); isLinked {
		return &GroupLinkError{"setup key", linkedSetupKey.Name}
	}

	if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.ID, group.AccountID); isLinked {
		return &GroupLinkError{"user", linkedUser.Id}
	}

	dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
	if err != nil {
		return err
	}

	if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
		return &GroupLinkError{"disabled DNS management groups", group.Name}
	}

	settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
	if err != nil {
		return err
	}

	// Extra settings are optional; without them there are no validator groups to check.
	if settings.Extra != nil {
		if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
			return &GroupLinkError{"integrated validator", group.Name}
		}
	}

	return nil
}
// isGroupLinkedToRoute reports whether any route of the account lists the
// group in its Groups or PeerGroups, returning the first such route.
// A failing store lookup is logged and reported as "not linked".
func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, groupID string, accountID string) (bool, *route.Route) {
	routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
	if err != nil {
		log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
		return false, nil
	}

	for i := range routes {
		candidate := routes[i]
		if !slices.Contains(candidate.Groups, groupID) && !slices.Contains(candidate.PeerGroups, groupID) {
			continue
		}
		return true, candidate
	}

	return false, nil
}
// isGroupLinkedToPolicy reports whether any policy of the account has a rule
// that lists the group among its Sources or Destinations, returning the first
// such policy. A failing store lookup is logged and reported as "not linked".
func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, groupID string, accountID string) (bool, *Policy) {
	policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
	if err != nil {
		log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
		return false, nil
	}

	for i := range policies {
		candidate := policies[i]
		for _, rule := range candidate.Rules {
			if !slices.Contains(rule.Sources, groupID) && !slices.Contains(rule.Destinations, groupID) {
				continue
			}
			return true, candidate
		}
	}

	return false, nil
}
// isGroupLinkedToDns reports whether any name server group of the account
// lists the group in its Groups, returning the first such name server group.
// A failing store lookup is logged and reported as "not linked".
func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, groupID string, accountID string) (bool, *nbdns.NameServerGroup) {
	nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
	if err != nil {
		log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
		return false, nil
	}

	for _, nsGroup := range nameServerGroups {
		if slices.Contains(nsGroup.Groups, groupID) {
			return true, nsGroup
		}
	}

	return false, nil
}
// isGroupLinkedToSetupKey reports whether any setup key of the account lists
// the group in its AutoGroups, returning the first such setup key.
// A failing store lookup is logged and reported as "not linked".
func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, groupID string, accountID string) (bool, *SetupKey) {
	setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
	if err != nil {
		log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
		return false, nil
	}

	for i := range setupKeys {
		key := setupKeys[i]
		if !slices.Contains(key.AutoGroups, groupID) {
			continue
		}
		return true, key
	}

	return false, nil
}
// isGroupLinkedToUser reports whether any user of the account lists the group
// in its AutoGroups, returning the first such user.
// A failing store lookup is logged and reported as "not linked".
func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, groupID string, accountID string) (bool, *User) {
	users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
	if err != nil {
		log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
		return false, nil
	}

	for i := range users {
		member := users[i]
		if !slices.Contains(member.AutoGroups, groupID) {
			continue
		}
		return true, member
	}

	return false, nil
}