diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index e7a2bc2ae..6e8a3baad 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -120,6 +120,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt updatedNSGroup := &nbdns.NameServerGroup{ ID: nsGroupID, + AccountID: accountID, Name: req.Name, Description: req.Description, Primary: req.Primary, diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 0eb5d9ae4..eecfb3355 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,7 +3,9 @@ package server import ( "context" "errors" + "fmt" "regexp" + "slices" "unicode/utf8" "github.com/miekg/dns" @@ -33,17 +35,18 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "no permission to create nameserver for this account") + } + newNSGroup := &nbdns.NameServerGroup{ ID: xid.New().String(), + AccountID: accountID, Name: name, Description: description, NameServers: nameServerList, @@ -54,92 +57,127 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco SearchDomainsEnabled: searchDomainEnabled, } - err = validateNameServerGroup(false, newNSGroup, account) + err = am.validateNameServerGroup(ctx, accountID, newNSGroup) if err != nil { return nil, err } - if account.NameServerGroups == nil { - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } - account.NameServerGroups[newNSGroup.ID] = newNSGroup + if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup); err != nil { + return fmt.Errorf("failed to create nameserver group: %w", err) + } - account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) + return nil + }) if err != nil { return nil, err } - am.updateAccountPeers(ctx, account) - am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) + return newNSGroup.Copy(), nil } // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if nsGroupToSave == nil { return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - err = validateNameServerGroup(true, nsGroupToSave, account) + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") + } + + _, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID) if err != nil { return err } - account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave - - account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) + err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave) if err != nil { return err } - am.updateAccountPeers(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave); err != nil { + return fmt.Errorf("failed to update nameserver group: %w", err) + } + + return nil + }) + if err != nil { + return err + } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) + return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - nsGroup := account.NameServerGroups[nsGroupID] - if nsGroup == nil { - return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID) + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") } - delete(account.NameServerGroups, nsGroupID) - account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) + nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) if err != nil { return err } - am.updateAccountPeers(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil { + return fmt.Errorf("failed to delete nameserver group: %w", err) + } + + return nil + }) + if err != nil { + return err + } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) + return nil } @@ -157,32 +195,33 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } -func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { - nsGroupID := "" - if existingGroup { - nsGroupID = nameserverGroup.ID - _, found := account.NameServerGroups[nsGroupID] - if !found { - return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) - } - } - +func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err } - err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) - if err != nil { - return err - } - err = validateNSList(nameserverGroup.NameServers) if err != nil { return err } - err = validateGroups(nameserverGroup.Groups, account.Groups) + nsServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups) + if err != nil { + return err + } + + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return err + } + + err = validateGroups(nameserverGroup.Groups, groups) if err != nil { return err } @@ -213,14 +252,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo return nil } -func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { +func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) } - for _, nsGroup := range nsGroupMap { + for _, nsGroup := range groups { if name == nsGroup.Name && nsGroup.ID != nsGroupID { - return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) + return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name) } } @@ -228,14 +267,14 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na } func validateNSList(list []nbdns.NameServer) error { - nsListLenght := len(list) - if nsListLenght == 0 || nsListLenght > 3 { + nsListLength := len(list) + if nsListLength == 0 || nsListLength > 3 { return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) } return nil } -func validateGroups(list []string, groups map[string]*nbgroup.Group) error { +func validateGroups(list []string, groups []*nbgroup.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } @@ -244,13 +283,8 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if id == "" { return status.Errorf(status.InvalidArgument, "group ID should not be empty string") } - found := false - for groupID := range groups { - if id == groupID { - found = true - break - } - } + + found := slices.ContainsFunc(groups, func(group *nbgroup.Group) bool { return group.ID == id }) if !found { return status.Errorf(status.InvalidArgument, "group id %s not found", id) }