Refactor name server groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-13 20:41:30 +03:00
parent 4b943c34b7
commit 218345e0ff
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 187 additions and 82 deletions

View File

@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Description: description,
NameServers: nameServerList,
@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
err = validateNameServerGroup(false, newNSGroup, account)
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
})
if err != nil {
return nil, err
}
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if am.anyGroupHasPeers(account, newNSGroup.Groups) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil
}
@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
err = validateNameServerGroup(true, nsGroupToSave, account)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
})
if err != nil {
return err
}
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
nsGroup := account.NameServerGroups[nsGroupID]
if nsGroup == nil {
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
if err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
})
if err != nil {
return err
}
if am.anyGroupHasPeers(account, nsGroup.Groups) {
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil
}
@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, account.Groups)
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
return nil
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
if err != nil {
return err
}
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
}
}
@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
}
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 3 {
nsListLength := len(list)
if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
}
return nil
@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
if _, found := groups[id]; !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
}
}
@ -277,11 +340,3 @@ func validateDomain(domain string) error {
return nil
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@ -1489,12 +1489,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
}
return nsGroups, nil
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
var nsGroup *nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
}
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
}
return nsGroup, nil
}
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}
// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store")
}
if result.RowsAffected == 0 {
return status.NewNameServerGroupNotFoundError(nsGroupID)
}
return nil
}
// getRecords retrieves records from the database based on the account ID.

View File

@ -149,3 +149,8 @@ func NewPostureChecksNotFoundError(postureChecksID string) error {
func NewPolicyNotFoundError(policyID string) error {
return Errorf(NotFound, "policy: %s not found", policyID)
}
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
func NewNameServerGroupNotFoundError(nsGroupID string) error {
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
}

View File

@ -117,6 +117,8 @@ type Store interface {
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error