mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-13 08:57:28 +02:00
refactor name server groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@ -120,6 +120,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
|
|
||||||
updatedNSGroup := &nbdns.NameServerGroup{
|
updatedNSGroup := &nbdns.NameServerGroup{
|
||||||
ID: nsGroupID,
|
ID: nsGroupID,
|
||||||
|
AccountID: accountID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Primary: req.Primary,
|
Primary: req.Primary,
|
||||||
|
@ -3,7 +3,9 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"slices"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -33,17 +35,18 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
|||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "no permission to create nameserver for this account")
|
||||||
|
}
|
||||||
|
|
||||||
newNSGroup := &nbdns.NameServerGroup{
|
newNSGroup := &nbdns.NameServerGroup{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
|
AccountID: accountID,
|
||||||
Name: name,
|
Name: name,
|
||||||
Description: description,
|
Description: description,
|
||||||
NameServers: nameServerList,
|
NameServers: nameServerList,
|
||||||
@ -54,92 +57,127 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
SearchDomainsEnabled: searchDomainEnabled,
|
SearchDomainsEnabled: searchDomainEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNameServerGroup(false, newNSGroup, account)
|
err = am.validateNameServerGroup(ctx, accountID, newNSGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.NameServerGroups == nil {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
}
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup); err != nil {
|
||||||
|
return fmt.Errorf("failed to create nameserver group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
return nil
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
return newNSGroup.Copy(), nil
|
return newNSGroup.Copy(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveNameServerGroup saves nameserver group
|
// SaveNameServerGroup saves nameserver group
|
||||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
if nsGroupToSave == nil {
|
if nsGroupToSave == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
if user.AccountID != accountID {
|
||||||
|
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave)
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave); err != nil {
|
||||||
|
return fmt.Errorf("failed to update nameserver group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup := account.NameServerGroups[nsGroupID]
|
if user.AccountID != accountID {
|
||||||
if nsGroup == nil {
|
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
||||||
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
|
|
||||||
}
|
}
|
||||||
delete(account.NameServerGroups, nsGroupID)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.updateAccountPeers(ctx, account)
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete nameserver group: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account: %w", err)
|
||||||
|
}
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,32 +195,33 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
|||||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||||
nsGroupID := ""
|
|
||||||
if existingGroup {
|
|
||||||
nsGroupID = nameserverGroup.ID
|
|
||||||
_, found := account.NameServerGroups[nsGroupID]
|
|
||||||
if !found {
|
|
||||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = validateNSList(nameserverGroup.NameServers)
|
err = validateNSList(nameserverGroup.NameServers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateGroups(nameserverGroup.Groups, account.Groups)
|
nsServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateGroups(nameserverGroup.Groups, groups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -213,14 +252,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
|
||||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, nsGroup := range nsGroupMap {
|
for _, nsGroup := range groups {
|
||||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||||
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,14 +267,14 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
|||||||
}
|
}
|
||||||
|
|
||||||
func validateNSList(list []nbdns.NameServer) error {
|
func validateNSList(list []nbdns.NameServer) error {
|
||||||
nsListLenght := len(list)
|
nsListLength := len(list)
|
||||||
if nsListLenght == 0 || nsListLenght > 3 {
|
if nsListLength == 0 || nsListLength > 3 {
|
||||||
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
func validateGroups(list []string, groups []*nbgroup.Group) error {
|
||||||
if len(list) == 0 {
|
if len(list) == 0 {
|
||||||
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
|
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
|
||||||
}
|
}
|
||||||
@ -244,13 +283,8 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
|||||||
if id == "" {
|
if id == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||||
}
|
}
|
||||||
found := false
|
|
||||||
for groupID := range groups {
|
found := slices.ContainsFunc(groups, func(group *nbgroup.Group) bool { return group.ID == id })
|
||||||
if id == groupID {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
if !found {
|
||||||
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user