mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-20 21:08:45 +01:00
[management] Refactor DNS settings to use store methods (#2883)
This commit is contained in:
parent
f118d81d32
commit
0e48a772ff
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
||||
}
|
||||
|
||||
if dnsSettingsToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
if !user.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range addedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
for _, groupID := range removedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
|
@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
|
||||
}
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
}
|
||||
@ -1537,3 +1538,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save dns settings to store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1857,3 +1857,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, policy)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetDNSSettings(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing account dns settings",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing account dns settings",
|
||||
accountID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve dns settings with empty account ID",
|
||||
accountID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, dnsSettings)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dnsSettings)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveDNSSettings(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
|
||||
err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, saveDNSSettings, dnsSettings)
|
||||
}
|
||||
|
@ -59,6 +59,7 @@ type Store interface {
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
|
Loading…
Reference in New Issue
Block a user